[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)

* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests

Signed-off-by: kim <grufwub@gmail.com>

* update code comments, defer mutex unlocks

Signed-off-by: kim <grufwub@gmail.com>

* add count to 'performing request' log message

Signed-off-by: kim <grufwub@gmail.com>

* reduce repeated conversions of same url.URL object

Signed-off-by: kim <grufwub@gmail.com>

* move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue

Signed-off-by: kim <grufwub@gmail.com>

* fix security advisories regarding max outgoing conns, max rsp body size

- implemented by a new httpclient.Client{} that wraps an underlying
  client with a queue to limit connections, and limit reader wrapping
  a response body with a configured maximum size
- update pub.HttpClient args passed around to be this new httpclient.Client{}

Signed-off-by: kim <grufwub@gmail.com>

* add httpclient tests, move ip validation to separate package + change mechanism

Signed-off-by: kim <grufwub@gmail.com>

* fix merge conflicts

Signed-off-by: kim <grufwub@gmail.com>

* use singular mutex in transport rather than separate signer mus

Signed-off-by: kim <grufwub@gmail.com>

* improved useragent string

Signed-off-by: kim <grufwub@gmail.com>

* add note regarding missing test

Signed-off-by: kim <grufwub@gmail.com>

* remove useragent field from transport (instead store in controller)

Signed-off-by: kim <grufwub@gmail.com>

* shutup linter

Signed-off-by: kim <grufwub@gmail.com>

* reset other signing headers on each loop iteration

Signed-off-by: kim <grufwub@gmail.com>

* respect request ctx during retry-backoff sleep period

Signed-off-by: kim <grufwub@gmail.com>

* use external pkg with docs explaining performance "hack"

Signed-off-by: kim <grufwub@gmail.com>

* use http package constants instead of string method literals

Signed-off-by: kim <grufwub@gmail.com>

* add license file headers

Signed-off-by: kim <grufwub@gmail.com>

* update code comment to match new func names

Signed-off-by: kim <grufwub@gmail.com>

* updates to user-agent string

Signed-off-by: kim <grufwub@gmail.com>

* update signed testrig models to fit with new transport logic (instead uses separate signer now)

Signed-off-by: kim <grufwub@gmail.com>

* fuck you linter

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2022-05-15 10:16:43 +01:00 committed by GitHub
parent 4ac508f037
commit 223025fc27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
61 changed files with 1801 additions and 435 deletions

View file

@ -21,7 +21,6 @@ package server
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"path"
@ -56,12 +55,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -71,7 +72,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Start creates and starts a gotosocial server
@ -93,8 +93,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
// NOTE: these MUST NOT be used until they are passed to the
// processor and it is started. The reason being that the processor
// sets the Worker process functions and start the underlying pools
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
federatingDB := federatingdb.New(dbService, fedWorker)
@ -120,13 +120,16 @@ var Start action.GTSAction = func(ctx context.Context) error {
return fmt.Errorf("error creating storage backend: %s", err)
}
// Build HTTP client (TODO: add configurables here)
client := httpclient.New(httpclient.Config{})
// build backend handlers
mediaManager, err := media.NewManager(dbService, storage)
if err != nil {
return fmt.Errorf("error creating media manager: %s", err)
}
oauthServer := oauth.New(ctx, dbService)
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, http.DefaultClient)
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
// decide whether to create a noop email sender (won't send emails) or a real one

View file

@ -54,11 +54,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -74,8 +74,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
testrig.StandardStorageSetup(storageBackend, "./testrig/media")
// Create client API and federator worker pools
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// build backend handlers
oauthServer := testrig.NewTestOauthServer(dbService)

5
go.mod
View file

@ -3,7 +3,10 @@ module github.com/superseriousbusiness/gotosocial
go 1.18
require (
codeberg.org/gruf/go-byteutil v1.0.1
codeberg.org/gruf/go-cache/v2 v2.0.1
codeberg.org/gruf/go-debug v1.1.2
codeberg.org/gruf/go-errors/v2 v2.0.1
codeberg.org/gruf/go-mutexes v1.1.2
codeberg.org/gruf/go-runners v1.2.1
codeberg.org/gruf/go-store v1.3.7
@ -52,8 +55,6 @@ require (
require (
codeberg.org/gruf/go-bitutil v1.0.0 // indirect
codeberg.org/gruf/go-bytes v1.0.2 // indirect
codeberg.org/gruf/go-byteutil v1.0.0 // indirect
codeberg.org/gruf/go-errors/v2 v2.0.1 // indirect
codeberg.org/gruf/go-fastcopy v1.1.1 // indirect
codeberg.org/gruf/go-fastpath v1.0.3 // indirect
codeberg.org/gruf/go-hashenc v1.0.2 // indirect

5
go.sum
View file

@ -40,9 +40,12 @@ codeberg.org/gruf/go-bitutil v1.0.0/go.mod h1:sb8IjlDnjVTz8zPK/8lmHesKxY0Yb3iqHW
codeberg.org/gruf/go-bytes v1.0.0/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
codeberg.org/gruf/go-bytes v1.0.2 h1:malqE42Ni+h1nnYWBUAJaDDtEzF4aeN4uPN8DfMNNvo=
codeberg.org/gruf/go-bytes v1.0.2/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
codeberg.org/gruf/go-byteutil v1.0.0 h1:xgKFNj/gH1r3yRo7gnyR4qrAKyeWCXs6B19ISX0DUAY=
codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-byteutil v1.0.1 h1:cOSaqe2aytOTAC5NM62LI0w8qPfJ9n2BBddk5KyMgd0=
codeberg.org/gruf/go-byteutil v1.0.1/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-cache v1.1.2/go.mod h1:/Dbc+xU72Op3hMn6x2PXF3NE9uIDFeS+sXPF00hN/7o=
codeberg.org/gruf/go-cache/v2 v2.0.1 h1:dyyfn6W6jfUlD/HWu5oz48sowSgsfKKeg2lU6T0gRww=
codeberg.org/gruf/go-cache/v2 v2.0.1/go.mod h1:VyfrDnPVUXUKYVkXnFOHRO1EoN+8zrTC9jRU6VmL3p0=
codeberg.org/gruf/go-debug v1.1.2 h1:7Tqkktg60M/4WtXTTNUFH2T/6irBw4tI4viv7IRLZDE=
codeberg.org/gruf/go-debug v1.1.2/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg=
codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4=

View file

@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -20,7 +21,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -29,6 +29,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -38,7 +39,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -31,6 +31,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -28,6 +28,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -37,6 +37,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -47,7 +48,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -35,6 +35,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -45,7 +46,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()

View file

@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)

View file

@ -22,6 +22,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()

View file

@ -33,11 +33,11 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -31,8 +31,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -33,8 +33,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -32,8 +32,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -32,7 +33,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)

View file

@ -33,9 +33,9 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
signedRequest := derefRequests["foss_satan_dereference_zork"]
targetAccount := suite.testAccounts["local_account_1"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)

View file

@ -31,10 +31,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)

View file

@ -1,4 +1,4 @@
package worker
package concurrency
import (
"context"
@ -12,17 +12,17 @@ import (
"github.com/sirupsen/logrus"
)
// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
type Worker[MsgType any] struct {
// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
type WorkerPool[MsgType any] struct {
workers runners.WorkerPool
process func(context.Context, MsgType) error
prefix string // contains type prefix for logging
}
// New returns a new Worker[MsgType] with given number of workers and queue ratio,
// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
var zero MsgType
if workers < 1 {
@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
msgType := reflect.TypeOf(zero).String()
_, msgType = path.Split(msgType)
w := &Worker[MsgType]{
w := &WorkerPool[MsgType]{
workers: runners.NewWorkerPool(workers, workers*queueRatio),
process: nil,
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
}
// Start will attempt to start the underlying worker pool, or return error.
func (w *Worker[MsgType]) Start() error {
func (w *WorkerPool[MsgType]) Start() error {
logrus.Infof("%s starting", w.prefix)
// Check processor was set
@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
}
// Stop will attempt to stop the underlying worker pool, or return error.
func (w *Worker[MsgType]) Stop() error {
func (w *WorkerPool[MsgType]) Stop() error {
logrus.Infof("%s stopping", w.prefix)
// Attempt to stop pool
@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
}
// SetProcessor will set the Worker's processor function, which is called for each queued message.
func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
if w.process != nil {
logrus.Panicf("%s Worker.process is already set", w.prefix)
}
@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
}
// Queue will queue provided message to be processed with there's a free worker.
func (w *Worker[MsgType]) Queue(msg MsgType) {
func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
)

View file

@ -29,12 +29,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
return response, nil
}
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}

View file

@ -28,10 +28,10 @@ import (
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}

View file

@ -24,10 +24,10 @@ import (
"codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
@ -44,12 +44,12 @@ type DB interface {
type federatingDB struct {
locks mutexes.MutexMap
db db.DB
fedWorker *worker.Worker[messages.FromFederator]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
typeConverter typeutils.TypeConverter
}
// New returns a DB interface using the given database and config
func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults
db: db,

View file

@ -23,12 +23,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
suite.Suite
db db.DB
tc typeutils.TypeConverter
fedWorker *worker.Worker[messages.FromFederator]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator
federatingDB federatingdb.DB
@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.fromFederator = make(chan messages.FromFederator, 10)
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
suite.fromFederator <- msg

View file

@ -28,10 +28,10 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
// the activity we're gonna use
activity := suite.testActivities["dm_for_zork"]
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
// now setup module being tested, with the mock transport controller

View file

@ -0,0 +1,199 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient
import (
"errors"
"io"
"net"
"net/http"
"net/netip"
"runtime"
"time"
)
// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
var ErrBodyTooLarge = errors.New("body size too large")
// dialer is the base net.Dialer used by all package-created http.Transports.
var dialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: &net.Resolver{Dial: nil},
}
// Config provides configuration details for setting up a new
// instance of httpclient.Client{}. Within are a subset of the
// configuration values passed to initialized http.Transport{}
// and http.Client{}, along with httpclient.Client{} specific.
type Config struct {
// MaxOpenConns limits the max number of concurrent open connections.
MaxOpenConns int
// MaxIdleConns: see http.Transport{}.MaxIdleConns.
MaxIdleConns int
// ReadBufferSize: see http.Transport{}.ReadBufferSize.
ReadBufferSize int
// WriteBufferSize: see http.Transport{}.WriteBufferSize.
WriteBufferSize int
// MaxBodySize determines the maximum fetchable body size.
MaxBodySize int64
// Timeout: see http.Client{}.Timeout.
Timeout time.Duration
// DisableCompression: see http.Transport{}.DisableCompression.
DisableCompression bool
// AllowRanges allows outgoing communications to given IP nets.
AllowRanges []netip.Prefix
// BlockRanges blocks outgoing communiciations to given IP nets.
BlockRanges []netip.Prefix
}
// Client wraps an underlying http.Client{} to provide the following:
// - setting a maximum received request body size, returning error on
// large content lengths, and using a limited reader in all other
// cases to protect against forged / unknown content-lengths
// - protection from server side request forgery (SSRF) by only dialing
// out to known public IP prefixes, configurable with allows/blocks
// - limit number of concurrent requests, else blocking until a slot
// is available (context channels still respected)
type Client struct {
client http.Client
queue chan struct{}
bmax int64
}
// New returns a new instance of Client initialized using configuration.
func New(cfg Config) *Client {
var c Client
// Copy global
d := dialer
if cfg.MaxOpenConns <= 0 {
// By default base this value on GOMAXPROCS.
maxprocs := runtime.GOMAXPROCS(0)
cfg.MaxOpenConns = maxprocs * 10
}
if cfg.MaxIdleConns <= 0 {
// By default base this value on MaxOpenConns
cfg.MaxIdleConns = cfg.MaxOpenConns * 10
}
if cfg.MaxBodySize <= 0 {
// By default set this to a reasonable 40MB
cfg.MaxBodySize = 40 * 1024 * 1024
}
// Protect dialer with IP range sanitizer
d.Control = (&sanitizer{
allow: cfg.AllowRanges,
block: cfg.BlockRanges,
}).Sanitize
// Prepare client fields
c.bmax = cfg.MaxBodySize
c.queue = make(chan struct{}, cfg.MaxOpenConns)
c.client.Timeout = cfg.Timeout
// Set underlying HTTP client roundtripper
c.client.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
DialContext: d.DialContext,
MaxIdleConns: cfg.MaxIdleConns,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ReadBufferSize: cfg.ReadBufferSize,
WriteBufferSize: cfg.WriteBufferSize,
DisableCompression: cfg.DisableCompression,
}
return &c
}
// Do will perform given request when an available slot in the queue is available,
// and block until this time. For returned values, this follows the same semantics
// as the standard http.Client{}.Do() implementation except that response body will
// be wrapped by an io.LimitReader() to limit response body sizes.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
select {
// Request context cancelled
case <-req.Context().Done():
return nil, req.Context().Err()
// Slot in queue acquired
case c.queue <- struct{}{}:
// NOTE:
// Ideally here we would set the slot release to happen either
// on error return, or via callback from the response body closer.
// However when implementing this, there appear deadlocks between
// the channel queue here and the media manager worker pool. So
// currently we only place a limit on connections dialing out, but
// there may still be more connections open than len(c.queue) given
// that connections may not be closed until response body is closed.
// The current implementation will reduce the viability of denial of
// service attacks, but if there are future issues heed this advice :]
defer func() { <-c.queue }()
}
// Perform the HTTP request
rsp, err := c.client.Do(req)
if err != nil {
return nil, err
}
// Check response body not too large
if rsp.ContentLength > c.bmax {
return nil, ErrBodyTooLarge
}
// Seperate the body implementers
rbody := (io.Reader)(rsp.Body)
cbody := (io.Closer)(rsp.Body)
var limit int64
if limit = rsp.ContentLength; limit < 0 {
// If unknown, use max as reader limit
limit = c.bmax
}
// Don't trust them, limit body reads
rbody = io.LimitReader(rbody, limit)
// Wrap body with limit
rsp.Body = &struct {
io.Reader
io.Closer
}{rbody, cbody}
return rsp, nil
}

View file

@ -0,0 +1,154 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient_test
import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
)
var privateIPs = []string{
"http://127.0.0.1:80",
"http://0.0.0.0:80",
"http://192.168.0.1:80",
"http://192.168.1.0:80",
"http://10.0.0.0:80",
"http://172.16.0.0:80",
"http://10.255.255.255:80",
"http://172.31.255.255:80",
"http://255.255.255.255:80",
}
var bodies = []string{
"hello world!",
"{}",
`{"key": "value", "some": "kinda bullshit"}`,
"body with\r\nnewlines",
}
// Note:
// There is no test for the .MaxOpenConns implementation
// in the httpclient.Client{}, due to the difficult to test
// this. The block is only held for the actual dial out to
// the connection, so the usual test of blocking and holding
// open this queue slot to check we can't open another isn't
// an easy test here.
func TestHTTPClientSmallBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
}
}
func TestHTTPClientExactBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), len(body))
}
}
func TestHTTPClientLargeBody(t *testing.T) {
for _, body := range bodies {
_TestHTTPClientWithBody(t, []byte(body), len(body)-1)
}
}
func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
var (
handler http.HandlerFunc
expect []byte
expectErr error
)
// If this is a larger body, reslice and
// set error so we know what to expect
expect = body
if max < len(body) {
expect = expect[:max]
expectErr = httpclient.ErrBodyTooLarge
}
// Create new HTTP client with maximum body size
client := httpclient.New(httpclient.Config{
MaxBodySize: int64(max),
DisableCompression: true,
AllowRanges: []netip.Prefix{
// Loopback (used by server)
netip.MustParsePrefix("127.0.0.1/8"),
},
})
// Set simple body-writing test handler
handler = func(rw http.ResponseWriter, r *http.Request) {
_, _ = rw.Write(body)
}
// Start the test server
srv := httptest.NewServer(handler)
defer srv.Close()
// Wrap body to provide reader iface
rbody := bytes.NewReader(body)
// Create the test HTTP request
req, _ := http.NewRequest("POST", srv.URL, rbody)
// Perform the test request
rsp, err := client.Do(req)
if !errors.Is(err, expectErr) {
t.Fatalf("error performing client request: %v", err)
} else if err != nil {
return // expected error
}
defer rsp.Body.Close()
// Read response body into memory
check, err := io.ReadAll(rsp.Body)
if err != nil {
t.Fatalf("error reading response body: %v", err)
}
// Check actual response body matches expected
if !bytes.Equal(expect, check) {
t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
}
}
func TestHTTPClientPrivateIP(t *testing.T) {
client := httpclient.New(httpclient.Config{})
for _, addr := range privateIPs {
// Prepare request to private IP
req, _ := http.NewRequest("GET", addr, nil)
// Perform the HTTP request
_, err := client.Do(req)
if !errors.Is(err, httpclient.ErrReservedAddr) {
t.Errorf("dialing private address did not return expected error: %v", err)
}
}
}

View file

@ -0,0 +1,64 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package httpclient
import (
"net/netip"
"syscall"
"github.com/superseriousbusiness/gotosocial/internal/netutil"
)
type sanitizer struct {
allow []netip.Prefix
block []netip.Prefix
}
// Sanitize implements the required net.Dialer.Control function signature.
func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
// Parse IP+port from addr
ipport, err := netip.ParseAddrPort(addr)
if err != nil {
return err
}
// Seperate the IP
ip := ipport.Addr()
// Check if this is explicitly allowed
for i := 0; i < len(s.allow); i++ {
if s.allow[i].Contains(ip) {
return nil
}
}
// Now check if explicity blocked
for i := 0; i < len(s.block); i++ {
if s.block[i].Contains(ip) {
return ErrReservedAddr
}
}
// Validate this is a safe IP
if !netutil.ValidateIP(ip) {
return ErrReservedAddr
}
return nil
}

View file

@ -27,9 +27,9 @@ import (
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
@ -79,8 +79,8 @@ type Manager interface {
type manager struct {
db db.DB
storage *kv.KVStore
emojiWorker *worker.Worker[*ProcessingEmoji]
mediaWorker *worker.Worker[*ProcessingMedia]
emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
stopCronJobs func() error
}
@ -89,7 +89,7 @@ type manager struct {
// A worker pool will also be initialized for the manager, to ensure that only
// a limited number of media will be processed in parallel. The numbers of workers
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
// See internal/worker.New() documentation for further information.
// See internal/concurrency.NewWorkerPool() documentation for further information.
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
m := &manager{
db: database,
@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
}
// Prepare the media worker pool
m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
if err := ctx.Err(); err != nil {
return err
@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
})
// Prepare the emoji worker pool
m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
if err := ctx.Err(); err != nil {
return err

View file

@ -0,0 +1,78 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package netutil
import (
"net/netip"
)
var (
// IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
// IPvReserved contains IPv4 reserved IP prefixes.
IPv4Reserved = [...]netip.Prefix{
netip.MustParsePrefix("0.0.0.0/8"), // Current network
netip.MustParsePrefix("10.0.0.0/8"), // Private
netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
netip.MustParsePrefix("127.0.0.0/8"), // Loopback
netip.MustParsePrefix("169.254.0.0/16"), // Link-local
netip.MustParsePrefix("172.16.0.0/12"), // Private
netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
netip.MustParsePrefix("192.168.0.0/16"), // Private
netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
netip.MustParsePrefix("224.0.0.0/4"), // Multicast
netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
}
)
// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
func ValidateAddr(s string) bool {
ipport, err := netip.ParseAddrPort(s)
if err != nil {
return false
}
return ValidateIP(ipport.Addr())
}
// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
func ValidateIP(ip netip.Addr) bool {
switch {
// IPv4: check if IPv4 in reserved nets
case ip.Is4():
for _, reserved := range IPv4Reserved {
if reserved.Contains(ip) {
return false
}
}
return true
// IPv6: check if in global unicast (public internet)
case ip.Is6():
return IPv6GlobalUnicast.Contains(ip)
// Assume malicious by default
default:
return false
}
}

View file

@ -23,6 +23,7 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -33,7 +34,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/oauth2/v4"
)
@ -84,7 +84,7 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server
filter visibility.Filter
formatter text.Formatter
@ -94,7 +94,7 @@ type processor struct {
}
// New returns a new account processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,

View file

@ -24,6 +24,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -35,7 +36,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
suite.fromClientAPIChan <- msg
return nil

View file

@ -23,13 +23,13 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing admin actions.
@ -47,12 +47,12 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
db db.DB
}
// New returns a new admin processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,

View file

@ -26,6 +26,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
@ -33,7 +34,6 @@ import (
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
return response, nil
}
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}

View file

@ -25,6 +25,7 @@ import (
"codeberg.org/gruf/go-store/kv"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -44,7 +45,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
@ -237,8 +237,8 @@ type Processor interface {
// processor just implements the Processor interface
type processor struct {
clientWorker *worker.Worker[messages.FromClientAPI]
fedWorker *worker.Worker[messages.FromFederator]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
federator federation.Federator
tc typeutils.TypeConverter
@ -271,8 +271,8 @@ func NewProcessor(
storage *kv.KVStore,
db db.DB,
emailSender email.Sender,
clientWorker *worker.Worker[messages.FromClientAPI],
fedWorker *worker.Worker[messages.FromFederator],
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
fedWorker *concurrency.WorkerPool[messages.FromFederator],
) Processor {
parseMentionFunc := GetParseMentionFunc(db, federator)

View file

@ -29,6 +29,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
}, nil
})
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
fedWorker := worker.New[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)

View file

@ -22,6 +22,7 @@ import (
"context"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -29,7 +30,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing statuses.
@ -74,12 +74,12 @@ type processor struct {
db db.DB
filter visibility.Filter
formatter text.Formatter
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
db: db,

View file

@ -21,6 +21,7 @@ package status_test
import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
storage *kv.KVStore
mediaManager media.Manager
federator federation.Federator
clientWorker *worker.Worker[messages.FromClientAPI]
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
// standard suite models
testTokens map[string]*gtsmodel.Token
@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
fedWorker := worker.New[messages.FromFederator](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
suite.storage = testrig.NewTestStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)

View file

@ -20,13 +20,17 @@ package transport
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"fmt"
"net/url"
"sync"
"runtime/debug"
"time"
"github.com/go-fed/httpsig"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-cache/v2"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
@ -37,109 +41,85 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
// NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
// NewTransportForUsername searches for account with username, and returns result of .NewTransport().
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
db db.DB
fedDB federatingdb.DB
clock pub.Clock
client pub.HttpClient
appAgent string
// dereferenceFollowersShortcut is a shortcut to dereference followers of an
// account on this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
// dereferenceUserShortcut is a shortcut to dereference followers an account on
// this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
}
func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
followers, err := federatingDB.Followers(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(followers)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
}
func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
user, err := federatingDB.Get(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(user)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
cache cache.Cache[string, *transport]
userAgent string
}
// NewController returns an implementation of the Controller interface for creating new transports
func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
applicationName := viper.GetString(config.Keys.ApplicationName)
host := viper.GetString(config.Keys.Host)
appAgent := fmt.Sprintf("%s %s", applicationName, host)
return &controller{
// Determine build information
build, _ := debug.ReadBuildInfo()
c := &controller{
db: db,
fedDB: federatingDB,
clock: clock,
client: client,
appAgent: appAgent,
dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
}
cache: cache.New[string, *transport](),
userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
}
// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
digestAlgo := httpsig.DigestSha256
getHeaders := []string{httpsig.RequestTarget, "host", "date"}
postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
if err != nil {
return nil, fmt.Errorf("error creating get signer: %s", err)
// Transport cache has TTL=1hr freq=1m
c.cache.SetTTL(time.Hour, false)
if !c.cache.Start(time.Minute) {
logrus.Panic("failed to start transport controller cache")
}
postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
if err != nil {
return nil, fmt.Errorf("error creating post signer: %s", err)
return c
}
sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
// Generate public key string for cache key
//
// NOTE: it is safe to use the public key as the cache
// key here as we are generating it ourselves from the
// private key. If we were simply using a public key
// provided as argument that would absolutely NOT be safe.
pubStr := privkeyToPublicStr(privkey)
return &transport{
client: c.client,
appAgent: c.appAgent,
gofedAgent: "(go-fed/activity v1.0.0)",
clock: c.clock,
// First check for cached transport
transp, ok := c.cache.Get(pubStr)
if ok {
return transp, nil
}
// Create the transport
transp = &transport{
controller: c,
pubKeyID: pubKeyID,
privkey: privkey,
sigTransport: sigTransport,
getSigner: getSigner,
getSignerMu: &sync.Mutex{},
dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
dereferenceUserShortcut: c.dereferenceUserShortcut,
}, nil
}
// Cache this transport under pubkey
if !c.cache.Put(pubStr, transp) {
var cached *transport
cached, ok = c.cache.Get(pubStr)
if !ok {
// Some ridiculous race cond.
c.cache.Set(pubStr, transp)
} else {
// Use already cached
transp = cached
}
}
return transp, nil
}
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
}
return transport, nil
}
// dereferenceLocalFollowers is a shortcut to dereference followers of an
// account on this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
followers, err := c.fedDB.Followers(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(followers)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
// dereferenceLocalUser is a shortcut to dereference followers an account on
// this instance, without making any external api/http calls.
//
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
user, err := c.fedDB.Get(ctx, iri)
if err != nil {
return nil, err
}
i, err := streams.Serialize(user)
if err != nil {
return nil, err
}
return json.Marshal(i)
}
// privkeyToPublicStr will create a string representation of RSA public key from private.
func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
return byteutil.B2S(b)
}

View file

@ -19,13 +19,14 @@
package transport
import (
"bytes"
"context"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
return nil
}
logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
return t.sigTransport.Deliver(ctx, b, to)
urlStr := to.String()
req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
if err != nil {
return err
}
req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
req.Header.Add("Accept-Charset", "utf-8")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", to.Host)
resp, err := t.POST(req, b)
if err != nil {
return err
}
defer resp.Body.Close()
if code := resp.StatusCode; code != http.StatusOK &&
code != http.StatusCreated && code != http.StatusAccepted {
return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
}
return nil
}

View file

@ -20,32 +20,55 @@ package transport
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
l := logrus.WithField("func", "Dereference")
// if the request is to us, we can shortcut for certain URIs rather than going through
// the normal request flow, thereby saving time and energy
if iri.Host == viper.GetString(config.Keys.Host) {
if uris.IsFollowersPath(iri) {
// the request is for followers of one of our accounts, which we can shortcut
return t.dereferenceFollowersShortcut(ctx, iri)
return t.controller.dereferenceLocalFollowers(ctx, iri)
}
if uris.IsUserPath(iri) {
// the request is for one of our accounts, which we can shortcut
return t.dereferenceUserShortcut(ctx, iri)
return t.controller.dereferenceLocalUser(ctx, iri)
}
}
// the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
l.Debugf("performing GET to %s", iri.String())
return t.sigTransport.Dereference(ctx, iri)
// Build IRI just once
iriStr := iri.String()
// Prepare new HTTP request to endpoint
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
req.Header.Add("Accept-Charset", "utf-8")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, err
}
defer rsp.Body.Close()
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
return ioutil.ReadAll(rsp.Body)
}

View file

@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
}
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
l := logrus.WithField("func", "dereferenceByAPIV1Instance")
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: "api/v1/instance",
}
l.Debugf("performing GET to %s", cleanIRI.String())
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", cleanIRI.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
// Build IRI just once
iriStr := cleanIRI.String()
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
if len(b) == 0 {
req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
} else if len(b) == 0 {
return nil, errors.New("response bytes was len 0")
}
@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
}
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
l := logrus.WithField("func", "callNodeInfoWellKnown")
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: ".well-known/nodeinfo",
}
l.Debugf("performing GET to %s", cleanIRI.String())
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
// Build IRI just once
iriStr := cleanIRI.String()
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", cleanIRI.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if len(b) == 0 {
} else if len(b) == 0 {
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
}
@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
}
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
l := logrus.WithField("func", "callNodeInfo")
// Build IRI just once
iriStr := iri.String()
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if len(b) == 0 {
} else if len(b) == 0 {
return nil, errors.New("callNodeInfo: response bytes was len 0")
}

View file

@ -24,34 +24,31 @@ import (
"io"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
)
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
l := logrus.WithField("func", "DereferenceMedia")
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
// Build IRI just once
iriStr := iri.String()
// Prepare HTTP request to this media's IRI
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, 0, err
}
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, 0, err
}
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
if err != nil {
return nil, 0, err
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
resp, err := t.client.Do(req)
if err != nil {
return nil, 0, err
}
if resp.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
}
return resp.Body, int(resp.ContentLength), nil
return rsp.Body, int(rsp.ContentLength), nil
}

View file

@ -23,46 +23,36 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"github.com/sirupsen/logrus"
)
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
l := logrus.WithField("func", "Finger")
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
l.Debugf("performing GET to %s", urlString)
// Prepare URL string
urlStr := "https://" +
targetDomain +
"/.well-known/webfinger?resource=acct:" +
targetUsername + "@" + targetDomain
iri, err := url.Parse(urlString)
if err != nil {
return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
}
l.Debugf("performing GET to %s", iri.String())
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
// Generate new GET request from URL string
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
req.Header.Set("Host", iri.Host)
t.getSignerMu.Lock()
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
t.getSignerMu.Unlock()
req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", req.URL.Host)
// Perform the HTTP request
rsp, err := t.GET(req)
if err != nil {
return nil, err
}
resp, err := t.client.Do(req)
if err != nil {
return nil, err
defer rsp.Body.Close()
// Check for an expected status code
if rsp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
}
return ioutil.ReadAll(resp.Body)
return ioutil.ReadAll(rsp.Body)
}

View file

@ -0,0 +1,43 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package transport
import (
"github.com/go-fed/httpsig"
)
var (
// http signer preferences
prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
digestAlgo = httpsig.DigestSha256
getHeaders = []string{httpsig.RequestTarget, "host", "date"}
postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
)
// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
return sig, err
}
// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
return sig, err
}

View file

@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
"crypto/x509"
"errors"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
// SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
client pub.HttpClient
appAgent string
gofedAgent string
clock pub.Clock
controller *controller
pubKeyID string
privkey crypto.PrivateKey
sigTransport *pub.HttpSigTransport
signerExp time.Time
getSigner httpsig.Signer
getSignerMu *sync.Mutex
// shortcuts for dereferencing things that exist on our instance without making an http call to ourself
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
postSigner httpsig.Signer
signerMu sync.Mutex
}
func (t *transport) SigTransport() pub.Transport {
return t.sigTransport
// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
if r.Method != http.MethodGet {
return nil, errors.New("must be GET request")
}
return t.do(r, func(r *http.Request) error {
return t.signGET(r)
}, retryOn...)
}
// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
if r.Method != http.MethodPost {
return nil, errors.New("must be POST request")
}
return t.do(r, func(r *http.Request) error {
return t.signPOST(r, body)
}, retryOn...)
}
func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
const maxRetries = 5
backoff := time.Second * 2
// Start a log entry for this request
l := logrus.WithFields(logrus.Fields{
"pubKeyID": t.pubKeyID,
"method": r.Method,
"url": r.URL.String(),
})
for i := 0; i < maxRetries; i++ {
// Reset signing header fields
now := t.controller.clock.Now().UTC()
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
r.Header.Del("Signature")
r.Header.Del("Digest")
// Perform request signing
if err := signer(r); err != nil {
return nil, err
}
l.Infof("performing request")
// Attempt to perform request
rsp, err := t.controller.client.Do(r)
if err == nil { //nolint shutup linter
// TooManyRequest means we need to slow
// down and retry our request. Codes over
// 500 generally indicate temp. outages.
if code := rsp.StatusCode; code < 500 &&
code != http.StatusTooManyRequests &&
!containsInt(retryOn, rsp.StatusCode) {
return rsp, nil
}
// Generate error from status code for logging
err = errors.New(`http response "` + rsp.Status + `"`)
} else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
// Return early if context has cancelled
return nil, err
} else if strings.Contains(err.Error(), "stopped after 10 redirects") {
// Don't bother if net/http returned after too many redirects
return nil, err
} else if errors.As(err, &x509.UnknownAuthorityError{}) {
// Unknown authority errors we do NOT recover from
return nil, err
}
l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
select {
// Request ctx cancelled
case <-r.Context().Done():
return nil, r.Context().Err()
// Backoff for some time
case <-time.After(backoff):
backoff *= 2
}
}
return nil, errors.New("transport reached max retries")
}
// signGET will safely sign an HTTP GET request.
func (t *transport) signGET(r *http.Request) (err error) {
t.safesign(func() {
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
})
return
}
// signPOST will safely sign an HTTP POST request for given body.
func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
t.safesign(func() {
err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
})
return
}
// safesign will perform sign function within mutex protection,
// and ensured that httpsig.Signers are up-to-date.
func (t *transport) safesign(sign func()) {
// Perform within mu safety
t.signerMu.Lock()
defer t.signerMu.Unlock()
if now := time.Now(); now.After(t.signerExp) {
const expiry = 120
// Signers have expired and require renewal
t.getSigner, _ = NewGETSigner(expiry)
t.postSigner, _ = NewPOSTSigner(expiry)
t.signerExp = now.Add(time.Second * expiry)
}
// Perform signing
sign()
}
// containsInt checks if slice contains check.
func containsInt(slice []int, check int) bool {
for _, i := range slice {
if i == check {
return true
}
}
return false
}

View file

@ -1,13 +1,13 @@
package testrig
import (
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestFederatingDB returns a federating DB with the underlying db
func NewTestFederatingDB(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) federatingdb.DB {
func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB {
return federatingdb.New(db, fedWorker)
}

View file

@ -20,15 +20,15 @@ package testrig
import (
"codeberg.org/gruf/go-store/kv"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestFederator returns a federator with the given database and (mock!!) transport controller.
func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *worker.Worker[messages.FromFederator]) federation.Federator {
func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator {
return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager)
}

View file

@ -20,16 +20,16 @@ package testrig
import (
"codeberg.org/gruf/go-store/kv"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestProcessor returns a Processor suitable for testing purposes
func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI], fedWorker *worker.Worker[messages.FromFederator]) processing.Processor {
func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) processing.Processor {
return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker)
}

View file

@ -20,8 +20,6 @@ package testrig
import (
"bytes"
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@ -29,7 +27,6 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
@ -42,8 +39,7 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
// NewTestTokens returns a map of tokens keyed according to which account the token belongs to.
@ -1855,86 +1851,71 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
}
}
// GetSignatureForActivity does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
// the HTTP Signature for the given activity, public key ID, private key, and destination.
func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// create a client that basically just pulls the signature out of the request and sets it
client := &mockHTTPClient{
do: func(req *http.Request) (*http.Response, error) {
signatureHeader = req.Header.Get("Signature")
digestHeader = req.Header.Get("Digest")
dateHeader = req.Header.Get("Date")
r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
return &http.Response{
StatusCode: 200,
Body: r,
}, nil
},
}
// Create temporary federator worker for transport controller
fedWorker := worker.New[messages.FromFederator](-1, -1)
_ = fedWorker.Start()
defer func() { _ = fedWorker.Stop() }()
// use the client to create a new transport
c := NewTestTransportController(client, NewTestDB(), fedWorker)
tp, err := c.NewTransport(pubKeyID, privkey)
if err != nil {
panic(err)
}
// GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values.
func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// convert the activity into json bytes
m, err := activity.Serialize()
if err != nil {
panic(err)
}
bytes, err := json.Marshal(m)
b, err := json.Marshal(m)
if err != nil {
panic(err)
}
// trigger the delivery function for the underlying signature transport, which will trigger the 'do' function of the recorder above
if err := tp.SigTransport().Deliver(context.Background(), bytes, destination); err != nil {
// Prepare HTTP request signer
sig, err := transport.NewPOSTSigner(120)
if err != nil {
panic(err)
}
// Prepare a mock request ready for signing
r, err := http.NewRequest("POST", destination.String(), bytes.NewReader(b))
if err != nil {
panic(err)
}
r.Header.Set("Host", destination.Host)
r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
// Sign this new HTTP request
if err := sig.SignRequest(privkey, pubKeyID, r, b); err != nil {
panic(err)
}
// Load signed data from request
signatureHeader = r.Header.Get("Signature")
digestHeader = r.Header.Get("Digest")
dateHeader = r.Header.Get("Date")
// headers should now be populated
return
}
// GetSignatureForDereference does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
// the HTTP Signature for the given derefence GET request using public key ID, private key, and destination.
func GetSignatureForDereference(pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// create a client that basically just pulls the signature out of the request and sets it
client := &mockHTTPClient{
do: func(req *http.Request) (*http.Response, error) {
signatureHeader = req.Header.Get("Signature")
dateHeader = req.Header.Get("Date")
r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
return &http.Response{
StatusCode: 200,
Body: r,
}, nil
},
}
// Create temporary federator worker for transport controller
fedWorker := worker.New[messages.FromFederator](-1, -1)
_ = fedWorker.Start()
defer func() { _ = fedWorker.Stop() }()
// use the client to create a new transport
c := NewTestTransportController(client, NewTestDB(), fedWorker)
tp, err := c.NewTransport(pubKeyID, privkey)
// GetSignatureForDereference prepares a mock HTTP request as if it were going to dereference destination signed for privkey and pubKeyID, signs the request and returns the header values.
func GetSignatureForDereference(pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// Prepare HTTP request signer
sig, err := transport.NewGETSigner(120)
if err != nil {
panic(err)
}
// trigger the dereference function for the underlying signature transport, which will trigger the 'do' function of the recorder above
if _, err := tp.SigTransport().Dereference(context.Background(), destination); err != nil {
// Prepare a mock request ready for signing
r, err := http.NewRequest("GET", destination.String(), nil)
if err != nil {
panic(err)
}
r.Header.Set("Host", destination.Host)
r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
// Sign this new HTTP request
if err := sig.SignRequest(privkey, pubKeyID, r, nil); err != nil {
panic(err)
}
// Load signed data from request
signatureHeader = r.Header.Get("Signature")
digestHeader = r.Header.Get("Digest")
dateHeader = r.Header.Get("Date")
// headers should now be populated
return

View file

@ -24,11 +24,11 @@ import (
"net/http"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestTransportController returns a test transport controller with the given http client.
@ -40,7 +40,7 @@ import (
// Unlike the other test interfaces provided in this package, you'll probably want to call this function
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
// basis.
func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *worker.Worker[messages.FromFederator]) transport.Controller {
func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller {
return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client)
}

View file

@ -16,11 +16,30 @@ func Copy(b []byte) []byte {
}
// B2S returns a string representation of []byte without allocation.
//
// According to the Go spec strings are immutable and byte slices are not. The way this gets implemented is strings under the hood are:
// type StringHeader struct {
// Data uintptr
// Len int
// }
//
// while slices are:
// type SliceHeader struct {
// Data uintptr
// Len int
// Cap int
// }
// because being mutable, you can change the data, length etc, but the string has to promise to be read-only to all who get copies of it.
//
// So in practice when you do a conversion of `string(byteSlice)` it actually performs an allocation because it has to copy the contents of the byte slice into a safe read-only state.
//
// Being that the shared fields are in the same struct indices (no different offsets), means that if you have a byte slice you can "forcibly" cast it to a string. Which in a lot of situations can be risky, because then it means you have a string that is NOT immutable, as if someone changes the data in the originating byte slice then the string will reflect that change! Now while this does seem hacky, and it _kind_ of is, it is something that you see performed in the standard library. If you look at the definition for `strings.Builder{}.String()` you'll see this :)
func B2S(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// S2B returns a []byte representation of string without allocation (minus slice header).
// See B2S() code comment, and this function's implementation for a better understanding.
func S2B(s string) []byte {
var b []byte

9
vendor/codeberg.org/gruf/go-cache/v2/LICENSE generated vendored Normal file
View file

@ -0,0 +1,9 @@
MIT License
Copyright (c) 2021 gruf
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

3
vendor/codeberg.org/gruf/go-cache/v2/README.md generated vendored Normal file
View file

@ -0,0 +1,3 @@
# go-cache
A TTL cache designed to be used as a base for your own customizations, or used straight out of the box

67
vendor/codeberg.org/gruf/go-cache/v2/cache.go generated vendored Normal file
View file

@ -0,0 +1,67 @@
package cache
import "time"
// Cache represents a TTL cache with customizable callbacks, it
// exists here to abstract away the "unsafe" methods in the case that
// you do not want your own implementation atop TTLCache{}.
type Cache[Key comparable, Value any] interface {
// Start will start the cache background eviction routine with given sweep frequency.
// If already running or a freq <= 0 provided, this is a no-op. This will block until
// the eviction routine has started
Start(freq time.Duration) bool
// Stop will stop cache background eviction routine. If not running this is a no-op. This
// will block until the eviction routine has stopped
Stop() bool
// SetEvictionCallback sets the eviction callback to the provided hook
SetEvictionCallback(hook Hook[Key, Value])
// SetInvalidateCallback sets the invalidate callback to the provided hook
SetInvalidateCallback(hook Hook[Key, Value])
// SetTTL sets the cache item TTL. Update can be specified to force updates of existing items in
// the cache, this will simply add the change in TTL to their current expiry time
SetTTL(ttl time.Duration, update bool)
// Get fetches the value with key from the cache, extending its TTL
Get(key Key) (value Value, ok bool)
// Put attempts to place the value at key in the cache, doing nothing if
// a value with this key already exists. Returned bool is success state
Put(key Key, value Value) bool
// Set places the value at key in the cache. This will overwrite any
// existing value, and call the update callback so. Existing values
// will have their TTL extended upon update
Set(key Key, value Value)
// CAS will attempt to perform a CAS operation on 'key', using provided
// comparison and swap values. Returned bool is success.
CAS(key Key, cmp, swp Value) bool
// Swap will attempt to perform a swap on 'key', replacing the value there
// and returning the existing value. If no value exists for key, this will
// set the value and return the zero value for V.
Swap(key Key, swp Value) Value
// Has checks the cache for a value with key, this will not update TTL
Has(key Key) bool
// Invalidate deletes a value from the cache, calling the invalidate callback
Invalidate(key Key) bool
// Clear empties the cache, calling the invalidate callback
Clear()
// Size returns the current size of the cache
Size() int
}
// New returns a new initialized Cache.
func New[K comparable, V any]() Cache[K, V] {
c := TTLCache[K, V]{}
c.Init()
return &c
}

23
vendor/codeberg.org/gruf/go-cache/v2/compare.go generated vendored Normal file
View file

@ -0,0 +1,23 @@
package cache
import (
"reflect"
)
type Comparable interface {
Equal(any) bool
}
// Compare returns whether 2 values are equal using the Comparable
// interface, or failing that falls back to use reflect.DeepEqual().
func Compare(i1, i2 any) bool {
c1, ok1 := i1.(Comparable)
if ok1 {
return c1.Equal(i2)
}
c2, ok2 := i2.(Comparable)
if ok2 {
return c2.Equal(i1)
}
return reflect.DeepEqual(i1, i2)
}

6
vendor/codeberg.org/gruf/go-cache/v2/hook.go generated vendored Normal file
View file

@ -0,0 +1,6 @@
package cache
// Hook defines a function hook that can be supplied as a callback.
type Hook[Key comparable, Value any] func(key Key, value Value)
func emptyHook[K comparable, V any](K, V) {}

214
vendor/codeberg.org/gruf/go-cache/v2/lookup.go generated vendored Normal file
View file

@ -0,0 +1,214 @@
package cache
// LookupCfg is the LookupCache configuration.
type LookupCfg[OGKey, AltKey comparable, Value any] struct {
// RegisterLookups is called on init to register lookups
// within LookupCache's internal LookupMap
RegisterLookups func(*LookupMap[OGKey, AltKey])
// AddLookups is called on each addition to the cache, to
// set any required additional key lookups for supplied item
AddLookups func(*LookupMap[OGKey, AltKey], Value)
// DeleteLookups is called on each eviction/invalidation of
// an item in the cache, to remove any unused key lookups
DeleteLookups func(*LookupMap[OGKey, AltKey], Value)
}
// LookupCache is a cache built on-top of TTLCache, providing multi-key
// lookups for items in the cache by means of additional lookup maps. These
// maps simply store additional keys => original key, with hook-ins to automatically
// call user supplied functions on adding an item, or on updating/deleting an
// item to keep the LookupMap up-to-date.
type LookupCache[OGKey, AltKey comparable, Value any] interface {
Cache[OGKey, Value]
// GetBy fetches a cached value by supplied lookup identifier and key
GetBy(lookup string, key AltKey) (value Value, ok bool)
// CASBy will attempt to perform a CAS operation on supplied lookup identifier and key
CASBy(lookup string, key AltKey, cmp, swp Value) bool
// SwapBy will attempt to perform a swap operation on supplied lookup identifier and key
SwapBy(lookup string, key AltKey, swp Value) Value
// HasBy checks if a value is cached under supplied lookup identifier and key
HasBy(lookup string, key AltKey) bool
// InvalidateBy invalidates a value by supplied lookup identifier and key
InvalidateBy(lookup string, key AltKey) bool
}
type lookupTTLCache[OK, AK comparable, V any] struct {
config LookupCfg[OK, AK, V]
lookup LookupMap[OK, AK]
TTLCache[OK, V]
}
// NewLookup returns a new initialized LookupCache.
func NewLookup[OK, AK comparable, V any](cfg LookupCfg[OK, AK, V]) LookupCache[OK, AK, V] {
switch {
case cfg.RegisterLookups == nil:
panic("cache: nil lookups register function")
case cfg.AddLookups == nil:
panic("cache: nil lookups add function")
case cfg.DeleteLookups == nil:
panic("cache: nil delete lookups function")
}
c := lookupTTLCache[OK, AK, V]{config: cfg}
c.TTLCache.Init()
c.lookup.lookup = make(map[string]map[AK]OK)
c.config.RegisterLookups(&c.lookup)
c.SetEvictionCallback(nil)
c.SetInvalidateCallback(nil)
c.lookup.initd = true
return &c
}
func (c *lookupTTLCache[OK, AK, V]) SetEvictionCallback(hook Hook[OK, V]) {
if hook == nil {
hook = emptyHook[OK, V]
}
c.TTLCache.SetEvictionCallback(func(key OK, value V) {
hook(key, value)
c.config.DeleteLookups(&c.lookup, value)
})
}
func (c *lookupTTLCache[OK, AK, V]) SetInvalidateCallback(hook Hook[OK, V]) {
if hook == nil {
hook = emptyHook[OK, V]
}
c.TTLCache.SetInvalidateCallback(func(key OK, value V) {
hook(key, value)
c.config.DeleteLookups(&c.lookup, value)
})
}
func (c *lookupTTLCache[OK, AK, V]) GetBy(lookup string, key AK) (V, bool) {
c.Lock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
c.Unlock()
var value V
return value, false
}
v, ok := c.GetUnsafe(origKey)
c.Unlock()
return v, ok
}
func (c *lookupTTLCache[OK, AK, V]) Put(key OK, value V) bool {
c.Lock()
put := c.PutUnsafe(key, value)
if put {
c.config.AddLookups(&c.lookup, value)
}
c.Unlock()
return put
}
func (c *lookupTTLCache[OK, AK, V]) Set(key OK, value V) {
c.Lock()
defer c.Unlock()
c.SetUnsafe(key, value)
c.config.AddLookups(&c.lookup, value)
}
func (c *lookupTTLCache[OK, AK, V]) CASBy(lookup string, key AK, cmp, swp V) bool {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
return false
}
return c.CASUnsafe(origKey, cmp, swp)
}
func (c *lookupTTLCache[OK, AK, V]) SwapBy(lookup string, key AK, swp V) V {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
var value V
return value
}
return c.SwapUnsafe(origKey, swp)
}
func (c *lookupTTLCache[OK, AK, V]) HasBy(lookup string, key AK) bool {
c.Lock()
has := c.lookup.Has(lookup, key)
c.Unlock()
return has
}
func (c *lookupTTLCache[OK, AK, V]) InvalidateBy(lookup string, key AK) bool {
c.Lock()
defer c.Unlock()
origKey, ok := c.lookup.Get(lookup, key)
if !ok {
return false
}
c.InvalidateUnsafe(origKey)
return true
}
// LookupMap is a structure that provides lookups for
// keys to primary keys under supplied lookup identifiers.
// This is essentially a wrapper around map[string](map[K1]K2).
type LookupMap[OK comparable, AK comparable] struct {
initd bool
lookup map[string](map[AK]OK)
}
// RegisterLookup registers a lookup identifier in the LookupMap,
// note this can only be doing during the cfg.RegisterLookups() hook.
func (l *LookupMap[OK, AK]) RegisterLookup(id string) {
if l.initd {
panic("cache: cannot register lookup after initialization")
} else if _, ok := l.lookup[id]; ok {
panic("cache: lookup mapping already exists for identifier")
}
l.lookup[id] = make(map[AK]OK, 100)
}
// Get fetches an entry's primary key for lookup identifier and key.
func (l *LookupMap[OK, AK]) Get(id string, key AK) (OK, bool) {
keys, ok := l.lookup[id]
if !ok {
var key OK
return key, false
}
origKey, ok := keys[key]
return origKey, ok
}
// Set adds a lookup to the LookupMap under supplied lookup identifier,
// linking supplied key to the supplied primary (original) key.
func (l *LookupMap[OK, AK]) Set(id string, key AK, origKey OK) {
keys, ok := l.lookup[id]
if !ok {
panic("cache: invalid lookup identifier")
}
keys[key] = origKey
}
// Has checks if there exists a lookup for supplied identifier and key.
func (l *LookupMap[OK, AK]) Has(id string, key AK) bool {
keys, ok := l.lookup[id]
if !ok {
return false
}
_, ok = keys[key]
return ok
}
// Delete removes a lookup from LookupMap with supplied identifier and key.
func (l *LookupMap[OK, AK]) Delete(id string, key AK) {
keys, ok := l.lookup[id]
if !ok {
return
}
delete(keys, key)
}

333
vendor/codeberg.org/gruf/go-cache/v2/ttl.go generated vendored Normal file
View file

@ -0,0 +1,333 @@
package cache
import (
"context"
"sync"
"time"
"codeberg.org/gruf/go-runners"
)
// TTLCache is the underlying Cache implementation, providing both the base
// Cache interface and access to "unsafe" methods so that you may build your
// customized caches ontop of this structure.
type TTLCache[Key comparable, Value any] struct {
cache map[Key](*entry[Value])
evict Hook[Key, Value] // the evict hook is called when an item is evicted from the cache, includes manual delete
invalid Hook[Key, Value] // the invalidate hook is called when an item's data in the cache is invalidated
ttl time.Duration // ttl is the item TTL
svc runners.Service // svc manages running of the cache eviction routine
mu sync.Mutex // mu protects TTLCache for concurrent access
}
// Init performs Cache initialization, this MUST be called.
func (c *TTLCache[K, V]) Init() {
c.cache = make(map[K](*entry[V]), 100)
c.evict = emptyHook[K, V]
c.invalid = emptyHook[K, V]
c.ttl = time.Minute * 5
}
func (c *TTLCache[K, V]) Start(freq time.Duration) bool {
// Nothing to start
if freq <= 0 {
return false
}
// Track state of starting
done := make(chan struct{})
started := false
go func() {
ran := c.svc.Run(func(ctx context.Context) {
// Successfully started
started = true
close(done)
// start routine
c.run(ctx, freq)
})
// failed to start
if !ran {
close(done)
}
}()
<-done
return started
}
func (c *TTLCache[K, V]) Stop() bool {
return c.svc.Stop()
}
func (c *TTLCache[K, V]) run(ctx context.Context, freq time.Duration) {
t := time.NewTimer(freq)
for {
select {
// we got stopped
case <-ctx.Done():
if !t.Stop() {
<-t.C
}
return
// next tick
case <-t.C:
c.sweep()
t.Reset(freq)
}
}
}
// sweep attempts to evict expired items (with callback!) from cache.
func (c *TTLCache[K, V]) sweep() {
// Lock and defer unlock (in case of hook panic)
c.mu.Lock()
defer c.mu.Unlock()
// Fetch current time for TTL check
now := time.Now()
// Sweep the cache for old items!
for key, item := range c.cache {
if now.After(item.expiry) {
c.evict(key, item.value)
delete(c.cache, key)
}
}
}
// Lock locks the cache mutex.
func (c *TTLCache[K, V]) Lock() {
c.mu.Lock()
}
// Unlock unlocks the cache mutex.
func (c *TTLCache[K, V]) Unlock() {
c.mu.Unlock()
}
func (c *TTLCache[K, V]) SetEvictionCallback(hook Hook[K, V]) {
// Ensure non-nil hook
if hook == nil {
hook = emptyHook[K, V]
}
// Safely set evict hook
c.Lock()
c.evict = hook
c.Unlock()
}
func (c *TTLCache[K, V]) SetInvalidateCallback(hook Hook[K, V]) {
// Ensure non-nil hook
if hook == nil {
hook = emptyHook[K, V]
}
// Safely set invalidate hook
c.Lock()
c.invalid = hook
c.Unlock()
}
func (c *TTLCache[K, V]) SetTTL(ttl time.Duration, update bool) {
// Safely update TTL
c.Lock()
diff := ttl - c.ttl
c.ttl = ttl
if update {
// Update existing cache entries
for _, entry := range c.cache {
entry.expiry.Add(diff)
}
}
// We're done
c.Unlock()
}
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
c.Lock()
value, ok := c.GetUnsafe(key)
c.Unlock()
return value, ok
}
// GetUnsafe is the mutex-unprotected logic for Cache.Get().
func (c *TTLCache[K, V]) GetUnsafe(key K) (V, bool) {
item, ok := c.cache[key]
if !ok {
var value V
return value, false
}
item.expiry = time.Now().Add(c.ttl)
return item.value, true
}
func (c *TTLCache[K, V]) Put(key K, value V) bool {
c.Lock()
success := c.PutUnsafe(key, value)
c.Unlock()
return success
}
// PutUnsafe is the mutex-unprotected logic for Cache.Put().
func (c *TTLCache[K, V]) PutUnsafe(key K, value V) bool {
// If already cached, return
if _, ok := c.cache[key]; ok {
return false
}
// Create new cached item
c.cache[key] = &entry[V]{
value: value,
expiry: time.Now().Add(c.ttl),
}
return true
}
func (c *TTLCache[K, V]) Set(key K, value V) {
c.Lock()
defer c.Unlock() // defer in case of hook panic
c.SetUnsafe(key, value)
}
// SetUnsafe is the mutex-unprotected logic for Cache.Set(), it calls externally-set functions.
func (c *TTLCache[K, V]) SetUnsafe(key K, value V) {
item, ok := c.cache[key]
if ok {
// call invalidate hook
c.invalid(key, item.value)
} else {
// alloc new item
item = &entry[V]{}
c.cache[key] = item
}
// Update the item + expiry
item.value = value
item.expiry = time.Now().Add(c.ttl)
}
func (c *TTLCache[K, V]) CAS(key K, cmp V, swp V) bool {
c.Lock()
ok := c.CASUnsafe(key, cmp, swp)
c.Unlock()
return ok
}
// CASUnsafe is the mutex-unprotected logic for Cache.CAS().
func (c *TTLCache[K, V]) CASUnsafe(key K, cmp V, swp V) bool {
// Check for item
item, ok := c.cache[key]
if !ok || !Compare(item.value, cmp) {
return false
}
// Invalidate item
c.invalid(key, item.value)
// Update item + expiry
item.value = swp
item.expiry = time.Now().Add(c.ttl)
return ok
}
func (c *TTLCache[K, V]) Swap(key K, swp V) V {
c.Lock()
old := c.SwapUnsafe(key, swp)
c.Unlock()
return old
}
// SwapUnsafe is the mutex-unprotected logic for Cache.Swap().
func (c *TTLCache[K, V]) SwapUnsafe(key K, swp V) V {
// Check for item
item, ok := c.cache[key]
if !ok {
var value V
return value
}
// invalidate old item
c.invalid(key, item.value)
old := item.value
// update item + expiry
item.value = swp
item.expiry = time.Now().Add(c.ttl)
return old
}
func (c *TTLCache[K, V]) Has(key K) bool {
c.Lock()
ok := c.HasUnsafe(key)
c.Unlock()
return ok
}
// HasUnsafe is the mutex-unprotected logic for Cache.Has().
func (c *TTLCache[K, V]) HasUnsafe(key K) bool {
_, ok := c.cache[key]
return ok
}
func (c *TTLCache[K, V]) Invalidate(key K) bool {
c.Lock()
defer c.Unlock()
return c.InvalidateUnsafe(key)
}
// InvalidateUnsafe is mutex-unprotected logic for Cache.Invalidate().
func (c *TTLCache[K, V]) InvalidateUnsafe(key K) bool {
// Check if we have item with key
item, ok := c.cache[key]
if !ok {
return false
}
// Call hook, remove from cache
c.invalid(key, item.value)
delete(c.cache, key)
return true
}
func (c *TTLCache[K, V]) Clear() {
c.Lock()
defer c.Unlock()
c.ClearUnsafe()
}
// ClearUnsafe is mutex-unprotected logic for Cache.Clean().
func (c *TTLCache[K, V]) ClearUnsafe() {
for key, item := range c.cache {
c.invalid(key, item.value)
delete(c.cache, key)
}
}
func (c *TTLCache[K, V]) Size() int {
c.Lock()
sz := c.SizeUnsafe()
c.Unlock()
return sz
}
// SizeUnsafe is mutex unprotected logic for Cache.Size().
func (c *TTLCache[K, V]) SizeUnsafe() int {
return len(c.cache)
}
// entry represents an item in the cache, with
// it's currently calculated expiry time.
type entry[Value any] struct {
value Value
expiry time.Time
}

5
vendor/modules.txt vendored
View file

@ -4,9 +4,12 @@ codeberg.org/gruf/go-bitutil
# codeberg.org/gruf/go-bytes v1.0.2
## explicit; go 1.14
codeberg.org/gruf/go-bytes
# codeberg.org/gruf/go-byteutil v1.0.0
# codeberg.org/gruf/go-byteutil v1.0.1
## explicit; go 1.16
codeberg.org/gruf/go-byteutil
# codeberg.org/gruf/go-cache/v2 v2.0.1
## explicit; go 1.18
codeberg.org/gruf/go-cache/v2
# codeberg.org/gruf/go-debug v1.1.2
## explicit; go 1.16
codeberg.org/gruf/go-debug