mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-11-27 19:01:01 +00:00
[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)
* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests Signed-off-by: kim <grufwub@gmail.com> * update code comments, defer mutex unlocks Signed-off-by: kim <grufwub@gmail.com> * add count to 'performing request' log message Signed-off-by: kim <grufwub@gmail.com> * reduce repeated conversions of same url.URL object Signed-off-by: kim <grufwub@gmail.com> * move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue Signed-off-by: kim <grufwub@gmail.com> * fix security advisories regarding max outgoing conns, max rsp body size - implemented by a new httpclient.Client{} that wraps an underlying client with a queue to limit connections, and limit reader wrapping a response body with a configured maximum size - update pub.HttpClient args passed around to be this new httpclient.Client{} Signed-off-by: kim <grufwub@gmail.com> * add httpclient tests, move ip validation to separate package + change mechanism Signed-off-by: kim <grufwub@gmail.com> * fix merge conflicts Signed-off-by: kim <grufwub@gmail.com> * use singular mutex in transport rather than separate signer mus Signed-off-by: kim <grufwub@gmail.com> * improved useragent string Signed-off-by: kim <grufwub@gmail.com> * add note regarding missing test Signed-off-by: kim <grufwub@gmail.com> * remove useragent field from transport (instead store in controller) Signed-off-by: kim <grufwub@gmail.com> * shutup linter Signed-off-by: kim <grufwub@gmail.com> * reset other signing headers on each loop iteration Signed-off-by: kim <grufwub@gmail.com> * respect request ctx during retry-backoff sleep period Signed-off-by: kim <grufwub@gmail.com> * use external pkg with docs explaining performance "hack" Signed-off-by: kim <grufwub@gmail.com> * use http package constants instead of string method literals Signed-off-by: kim <grufwub@gmail.com> * add license file headers Signed-off-by: kim <grufwub@gmail.com> * update code comment to match new func names Signed-off-by: kim <grufwub@gmail.com> * updates to user-agent string Signed-off-by: kim <grufwub@gmail.com> * update signed testrig models to fit with new transport logic (instead uses separate signer now) Signed-off-by: kim <grufwub@gmail.com> * fuck you linter Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
parent
4ac508f037
commit
223025fc27
61 changed files with 1801 additions and 435 deletions
|
@ -21,7 +21,6 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
|
@ -56,12 +55,14 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
|
@ -71,7 +72,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/web"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Start creates and starts a gotosocial server
|
||||
|
@ -93,8 +93,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
|
|||
// NOTE: these MUST NOT be used until they are passed to the
|
||||
// processor and it is started. The reason being that the processor
|
||||
// sets the Worker process functions and start the underlying pools
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
federatingDB := federatingdb.New(dbService, fedWorker)
|
||||
|
||||
|
@ -120,13 +120,16 @@ var Start action.GTSAction = func(ctx context.Context) error {
|
|||
return fmt.Errorf("error creating storage backend: %s", err)
|
||||
}
|
||||
|
||||
// Build HTTP client (TODO: add configurables here)
|
||||
client := httpclient.New(httpclient.Config{})
|
||||
|
||||
// build backend handlers
|
||||
mediaManager, err := media.NewManager(dbService, storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating media manager: %s", err)
|
||||
}
|
||||
oauthServer := oauth.New(ctx, dbService)
|
||||
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, http.DefaultClient)
|
||||
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
|
||||
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
|
||||
|
||||
// decide whether to create a noop email sender (won't send emails) or a real one
|
||||
|
|
|
@ -54,11 +54,11 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oidc"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/web"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -74,8 +74,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
|
|||
testrig.StandardStorageSetup(storageBackend, "./testrig/media")
|
||||
|
||||
// Create client API and federator worker pools
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// build backend handlers
|
||||
oauthServer := testrig.NewTestOauthServer(dbService)
|
||||
|
|
5
go.mod
5
go.mod
|
@ -3,7 +3,10 @@ module github.com/superseriousbusiness/gotosocial
|
|||
go 1.18
|
||||
|
||||
require (
|
||||
codeberg.org/gruf/go-byteutil v1.0.1
|
||||
codeberg.org/gruf/go-cache/v2 v2.0.1
|
||||
codeberg.org/gruf/go-debug v1.1.2
|
||||
codeberg.org/gruf/go-errors/v2 v2.0.1
|
||||
codeberg.org/gruf/go-mutexes v1.1.2
|
||||
codeberg.org/gruf/go-runners v1.2.1
|
||||
codeberg.org/gruf/go-store v1.3.7
|
||||
|
@ -52,8 +55,6 @@ require (
|
|||
require (
|
||||
codeberg.org/gruf/go-bitutil v1.0.0 // indirect
|
||||
codeberg.org/gruf/go-bytes v1.0.2 // indirect
|
||||
codeberg.org/gruf/go-byteutil v1.0.0 // indirect
|
||||
codeberg.org/gruf/go-errors/v2 v2.0.1 // indirect
|
||||
codeberg.org/gruf/go-fastcopy v1.1.1 // indirect
|
||||
codeberg.org/gruf/go-fastpath v1.0.3 // indirect
|
||||
codeberg.org/gruf/go-hashenc v1.0.2 // indirect
|
||||
|
|
5
go.sum
5
go.sum
|
@ -40,9 +40,12 @@ codeberg.org/gruf/go-bitutil v1.0.0/go.mod h1:sb8IjlDnjVTz8zPK/8lmHesKxY0Yb3iqHW
|
|||
codeberg.org/gruf/go-bytes v1.0.0/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
|
||||
codeberg.org/gruf/go-bytes v1.0.2 h1:malqE42Ni+h1nnYWBUAJaDDtEzF4aeN4uPN8DfMNNvo=
|
||||
codeberg.org/gruf/go-bytes v1.0.2/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
|
||||
codeberg.org/gruf/go-byteutil v1.0.0 h1:xgKFNj/gH1r3yRo7gnyR4qrAKyeWCXs6B19ISX0DUAY=
|
||||
codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
|
||||
codeberg.org/gruf/go-byteutil v1.0.1 h1:cOSaqe2aytOTAC5NM62LI0w8qPfJ9n2BBddk5KyMgd0=
|
||||
codeberg.org/gruf/go-byteutil v1.0.1/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
|
||||
codeberg.org/gruf/go-cache v1.1.2/go.mod h1:/Dbc+xU72Op3hMn6x2PXF3NE9uIDFeS+sXPF00hN/7o=
|
||||
codeberg.org/gruf/go-cache/v2 v2.0.1 h1:dyyfn6W6jfUlD/HWu5oz48sowSgsfKKeg2lU6T0gRww=
|
||||
codeberg.org/gruf/go-cache/v2 v2.0.1/go.mod h1:VyfrDnPVUXUKYVkXnFOHRO1EoN+8zrTC9jRU6VmL3p0=
|
||||
codeberg.org/gruf/go-debug v1.1.2 h1:7Tqkktg60M/4WtXTTNUFH2T/6irBw4tI4viv7IRLZDE=
|
||||
codeberg.org/gruf/go-debug v1.1.2/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg=
|
||||
codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4=
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
@ -20,7 +21,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
@ -38,7 +39,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
@ -37,7 +38,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -37,6 +37,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
@ -47,7 +48,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
|
@ -45,7 +46,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
|
|||
testrig.StandardDBSetup(suite.db, nil)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -30,7 +31,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
|
|||
func (suite *UserStandardTestSuite) SetupTest() {
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
suite.testTokens = testrig.NewTestTokens()
|
||||
suite.testClients = testrig.NewTestClients()
|
||||
suite.testApplications = testrig.NewTestApplications()
|
||||
|
|
|
@ -33,11 +33,11 @@ import (
|
|||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
|
|||
suite.NoError(err)
|
||||
body := bytes.NewReader(bodyJson)
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -31,8 +31,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -33,8 +33,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -32,8 +32,8 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
|
|||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
targetStatus := suite.testStatuses["local_account_1_status_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -32,7 +33,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
|
|
|
@ -33,9 +33,9 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
|
|||
signedRequest := derefRequests["foss_satan_dereference_zork"]
|
||||
targetAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
|
|||
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
|
||||
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/security"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -37,7 +38,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
|
|
|
@ -31,10 +31,10 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
|
|||
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
|
||||
viper.Set(config.Keys.Host, "gts.example.org")
|
||||
viper.Set(config.Keys.AccountDomain, "example.org")
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
|
||||
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
|
||||
|
||||
|
@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
|
|||
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
|
||||
viper.Set(config.Keys.Host, "gts.example.org")
|
||||
viper.Set(config.Keys.AccountDomain, "example.org")
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
|
||||
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package worker
|
||||
package concurrency
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -12,17 +12,17 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
|
||||
type Worker[MsgType any] struct {
|
||||
// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
|
||||
type WorkerPool[MsgType any] struct {
|
||||
workers runners.WorkerPool
|
||||
process func(context.Context, MsgType) error
|
||||
prefix string // contains type prefix for logging
|
||||
}
|
||||
|
||||
// New returns a new Worker[MsgType] with given number of workers and queue ratio,
|
||||
// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
|
||||
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
|
||||
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
|
||||
func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
||||
func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
|
||||
var zero MsgType
|
||||
|
||||
if workers < 1 {
|
||||
|
@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
|||
msgType := reflect.TypeOf(zero).String()
|
||||
_, msgType = path.Split(msgType)
|
||||
|
||||
w := &Worker[MsgType]{
|
||||
w := &WorkerPool[MsgType]{
|
||||
workers: runners.NewWorkerPool(workers, workers*queueRatio),
|
||||
process: nil,
|
||||
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
|
||||
|
@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
|
|||
}
|
||||
|
||||
// Start will attempt to start the underlying worker pool, or return error.
|
||||
func (w *Worker[MsgType]) Start() error {
|
||||
func (w *WorkerPool[MsgType]) Start() error {
|
||||
logrus.Infof("%s starting", w.prefix)
|
||||
|
||||
// Check processor was set
|
||||
|
@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
|
|||
}
|
||||
|
||||
// Stop will attempt to stop the underlying worker pool, or return error.
|
||||
func (w *Worker[MsgType]) Stop() error {
|
||||
func (w *WorkerPool[MsgType]) Stop() error {
|
||||
logrus.Infof("%s stopping", w.prefix)
|
||||
|
||||
// Attempt to stop pool
|
||||
|
@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
|
|||
}
|
||||
|
||||
// SetProcessor will set the Worker's processor function, which is called for each queued message.
|
||||
func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
|
||||
func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
|
||||
if w.process != nil {
|
||||
logrus.Panicf("%s Worker.process is already set", w.prefix)
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
|
|||
}
|
||||
|
||||
// Queue will queue provided message to be processed with there's a free worker.
|
||||
func (w *Worker[MsgType]) Queue(msg MsgType) {
|
||||
func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
|
||||
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
|
||||
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
|
||||
)
|
|
@ -29,12 +29,12 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
|
|||
|
||||
return response, nil
|
||||
}
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
mockClient := testrig.NewMockHTTPClient(do)
|
||||
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
|
||||
}
|
||||
|
|
|
@ -28,10 +28,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
|
|||
)
|
||||
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
sentMessages := []*url.URL{}
|
||||
|
@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
|
|||
)
|
||||
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
sentMessages := []*url.URL{}
|
||||
|
|
|
@ -24,10 +24,10 @@ import (
|
|||
"codeberg.org/gruf/go-mutexes"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
|
||||
|
@ -44,12 +44,12 @@ type DB interface {
|
|||
type federatingDB struct {
|
||||
locks mutexes.MutexMap
|
||||
db db.DB
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
typeConverter typeutils.TypeConverter
|
||||
}
|
||||
|
||||
// New returns a DB interface using the given database and config
|
||||
func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
|
||||
func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
|
||||
fdb := federatingDB{
|
||||
locks: mutexes.NewMap(-1, -1), // use defaults
|
||||
db: db,
|
||||
|
|
|
@ -23,12 +23,12 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
|
|||
suite.Suite
|
||||
db db.DB
|
||||
tc typeutils.TypeConverter
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
fromFederator chan messages.FromFederator
|
||||
federatingDB federatingdb.DB
|
||||
|
||||
|
@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
|
|||
func (suite *FederatingDBTestSuite) SetupTest() {
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
|
||||
suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
suite.fromFederator = make(chan messages.FromFederator, 10)
|
||||
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
|
||||
suite.fromFederator <- msg
|
||||
|
|
|
@ -28,10 +28,10 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
|
|||
// the activity we're gonna use
|
||||
activity := suite.testActivities["dm_for_zork"]
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
// setup transport controller with a no-op client so we don't make external calls
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
|
||||
|
@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
|
|||
sendingAccount := suite.testAccounts["remote_account_1"]
|
||||
inboxAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
// now setup module being tested, with the mock transport controller
|
||||
|
|
199
internal/httpclient/client.go
Normal file
199
internal/httpclient/client.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
|
||||
var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
|
||||
|
||||
// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
|
||||
var ErrBodyTooLarge = errors.New("body size too large")
|
||||
|
||||
// dialer is the base net.Dialer used by all package-created http.Transports.
|
||||
var dialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: &net.Resolver{Dial: nil},
|
||||
}
|
||||
|
||||
// Config provides configuration details for setting up a new
|
||||
// instance of httpclient.Client{}. Within are a subset of the
|
||||
// configuration values passed to initialized http.Transport{}
|
||||
// and http.Client{}, along with httpclient.Client{} specific.
|
||||
type Config struct {
|
||||
// MaxOpenConns limits the max number of concurrent open connections.
|
||||
MaxOpenConns int
|
||||
|
||||
// MaxIdleConns: see http.Transport{}.MaxIdleConns.
|
||||
MaxIdleConns int
|
||||
|
||||
// ReadBufferSize: see http.Transport{}.ReadBufferSize.
|
||||
ReadBufferSize int
|
||||
|
||||
// WriteBufferSize: see http.Transport{}.WriteBufferSize.
|
||||
WriteBufferSize int
|
||||
|
||||
// MaxBodySize determines the maximum fetchable body size.
|
||||
MaxBodySize int64
|
||||
|
||||
// Timeout: see http.Client{}.Timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// DisableCompression: see http.Transport{}.DisableCompression.
|
||||
DisableCompression bool
|
||||
|
||||
// AllowRanges allows outgoing communications to given IP nets.
|
||||
AllowRanges []netip.Prefix
|
||||
|
||||
// BlockRanges blocks outgoing communiciations to given IP nets.
|
||||
BlockRanges []netip.Prefix
|
||||
}
|
||||
|
||||
// Client wraps an underlying http.Client{} to provide the following:
|
||||
// - setting a maximum received request body size, returning error on
|
||||
// large content lengths, and using a limited reader in all other
|
||||
// cases to protect against forged / unknown content-lengths
|
||||
// - protection from server side request forgery (SSRF) by only dialing
|
||||
// out to known public IP prefixes, configurable with allows/blocks
|
||||
// - limit number of concurrent requests, else blocking until a slot
|
||||
// is available (context channels still respected)
|
||||
type Client struct {
|
||||
client http.Client
|
||||
queue chan struct{}
|
||||
bmax int64
|
||||
}
|
||||
|
||||
// New returns a new instance of Client initialized using configuration.
|
||||
func New(cfg Config) *Client {
|
||||
var c Client
|
||||
|
||||
// Copy global
|
||||
d := dialer
|
||||
|
||||
if cfg.MaxOpenConns <= 0 {
|
||||
// By default base this value on GOMAXPROCS.
|
||||
maxprocs := runtime.GOMAXPROCS(0)
|
||||
cfg.MaxOpenConns = maxprocs * 10
|
||||
}
|
||||
|
||||
if cfg.MaxIdleConns <= 0 {
|
||||
// By default base this value on MaxOpenConns
|
||||
cfg.MaxIdleConns = cfg.MaxOpenConns * 10
|
||||
}
|
||||
|
||||
if cfg.MaxBodySize <= 0 {
|
||||
// By default set this to a reasonable 40MB
|
||||
cfg.MaxBodySize = 40 * 1024 * 1024
|
||||
}
|
||||
|
||||
// Protect dialer with IP range sanitizer
|
||||
d.Control = (&sanitizer{
|
||||
allow: cfg.AllowRanges,
|
||||
block: cfg.BlockRanges,
|
||||
}).Sanitize
|
||||
|
||||
// Prepare client fields
|
||||
c.bmax = cfg.MaxBodySize
|
||||
c.queue = make(chan struct{}, cfg.MaxOpenConns)
|
||||
c.client.Timeout = cfg.Timeout
|
||||
|
||||
// Set underlying HTTP client roundtripper
|
||||
c.client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ForceAttemptHTTP2: true,
|
||||
DialContext: d.DialContext,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ReadBufferSize: cfg.ReadBufferSize,
|
||||
WriteBufferSize: cfg.WriteBufferSize,
|
||||
DisableCompression: cfg.DisableCompression,
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// Do will perform given request when an available slot in the queue is available,
|
||||
// and block until this time. For returned values, this follows the same semantics
|
||||
// as the standard http.Client{}.Do() implementation except that response body will
|
||||
// be wrapped by an io.LimitReader() to limit response body sizes.
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
select {
|
||||
// Request context cancelled
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
|
||||
// Slot in queue acquired
|
||||
case c.queue <- struct{}{}:
|
||||
// NOTE:
|
||||
// Ideally here we would set the slot release to happen either
|
||||
// on error return, or via callback from the response body closer.
|
||||
// However when implementing this, there appear deadlocks between
|
||||
// the channel queue here and the media manager worker pool. So
|
||||
// currently we only place a limit on connections dialing out, but
|
||||
// there may still be more connections open than len(c.queue) given
|
||||
// that connections may not be closed until response body is closed.
|
||||
// The current implementation will reduce the viability of denial of
|
||||
// service attacks, but if there are future issues heed this advice :]
|
||||
defer func() { <-c.queue }()
|
||||
}
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check response body not too large
|
||||
if rsp.ContentLength > c.bmax {
|
||||
return nil, ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Seperate the body implementers
|
||||
rbody := (io.Reader)(rsp.Body)
|
||||
cbody := (io.Closer)(rsp.Body)
|
||||
|
||||
var limit int64
|
||||
|
||||
if limit = rsp.ContentLength; limit < 0 {
|
||||
// If unknown, use max as reader limit
|
||||
limit = c.bmax
|
||||
}
|
||||
|
||||
// Don't trust them, limit body reads
|
||||
rbody = io.LimitReader(rbody, limit)
|
||||
|
||||
// Wrap body with limit
|
||||
rsp.Body = &struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{rbody, cbody}
|
||||
|
||||
return rsp, nil
|
||||
}
|
154
internal/httpclient/client_test.go
Normal file
154
internal/httpclient/client_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
|
||||
)
|
||||
|
||||
var privateIPs = []string{
|
||||
"http://127.0.0.1:80",
|
||||
"http://0.0.0.0:80",
|
||||
"http://192.168.0.1:80",
|
||||
"http://192.168.1.0:80",
|
||||
"http://10.0.0.0:80",
|
||||
"http://172.16.0.0:80",
|
||||
"http://10.255.255.255:80",
|
||||
"http://172.31.255.255:80",
|
||||
"http://255.255.255.255:80",
|
||||
}
|
||||
|
||||
var bodies = []string{
|
||||
"hello world!",
|
||||
"{}",
|
||||
`{"key": "value", "some": "kinda bullshit"}`,
|
||||
"body with\r\nnewlines",
|
||||
}
|
||||
|
||||
// Note:
|
||||
// There is no test for the .MaxOpenConns implementation
|
||||
// in the httpclient.Client{}, due to the difficult to test
|
||||
// this. The block is only held for the actual dial out to
|
||||
// the connection, so the usual test of blocking and holding
|
||||
// open this queue slot to check we can't open another isn't
|
||||
// an easy test here.
|
||||
|
||||
func TestHTTPClientSmallBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientExactBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), len(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientLargeBody(t *testing.T) {
|
||||
for _, body := range bodies {
|
||||
_TestHTTPClientWithBody(t, []byte(body), len(body)-1)
|
||||
}
|
||||
}
|
||||
|
||||
func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
|
||||
var (
|
||||
handler http.HandlerFunc
|
||||
|
||||
expect []byte
|
||||
|
||||
expectErr error
|
||||
)
|
||||
|
||||
// If this is a larger body, reslice and
|
||||
// set error so we know what to expect
|
||||
expect = body
|
||||
if max < len(body) {
|
||||
expect = expect[:max]
|
||||
expectErr = httpclient.ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Create new HTTP client with maximum body size
|
||||
client := httpclient.New(httpclient.Config{
|
||||
MaxBodySize: int64(max),
|
||||
DisableCompression: true,
|
||||
AllowRanges: []netip.Prefix{
|
||||
// Loopback (used by server)
|
||||
netip.MustParsePrefix("127.0.0.1/8"),
|
||||
},
|
||||
})
|
||||
|
||||
// Set simple body-writing test handler
|
||||
handler = func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, _ = rw.Write(body)
|
||||
}
|
||||
|
||||
// Start the test server
|
||||
srv := httptest.NewServer(handler)
|
||||
defer srv.Close()
|
||||
|
||||
// Wrap body to provide reader iface
|
||||
rbody := bytes.NewReader(body)
|
||||
|
||||
// Create the test HTTP request
|
||||
req, _ := http.NewRequest("POST", srv.URL, rbody)
|
||||
|
||||
// Perform the test request
|
||||
rsp, err := client.Do(req)
|
||||
if !errors.Is(err, expectErr) {
|
||||
t.Fatalf("error performing client request: %v", err)
|
||||
} else if err != nil {
|
||||
return // expected error
|
||||
}
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Read response body into memory
|
||||
check, err := io.ReadAll(rsp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading response body: %v", err)
|
||||
}
|
||||
|
||||
// Check actual response body matches expected
|
||||
if !bytes.Equal(expect, check) {
|
||||
t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientPrivateIP(t *testing.T) {
|
||||
client := httpclient.New(httpclient.Config{})
|
||||
|
||||
for _, addr := range privateIPs {
|
||||
// Prepare request to private IP
|
||||
req, _ := http.NewRequest("GET", addr, nil)
|
||||
|
||||
// Perform the HTTP request
|
||||
_, err := client.Do(req)
|
||||
if !errors.Is(err, httpclient.ErrReservedAddr) {
|
||||
t.Errorf("dialing private address did not return expected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
64
internal/httpclient/sanitizer.go
Normal file
64
internal/httpclient/sanitizer.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/netutil"
|
||||
)
|
||||
|
||||
type sanitizer struct {
|
||||
allow []netip.Prefix
|
||||
block []netip.Prefix
|
||||
}
|
||||
|
||||
// Sanitize implements the required net.Dialer.Control function signature.
|
||||
func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
|
||||
// Parse IP+port from addr
|
||||
ipport, err := netip.ParseAddrPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seperate the IP
|
||||
ip := ipport.Addr()
|
||||
|
||||
// Check if this is explicitly allowed
|
||||
for i := 0; i < len(s.allow); i++ {
|
||||
if s.allow[i].Contains(ip) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Now check if explicity blocked
|
||||
for i := 0; i < len(s.block); i++ {
|
||||
if s.block[i].Contains(ip) {
|
||||
return ErrReservedAddr
|
||||
}
|
||||
}
|
||||
|
||||
// Validate this is a safe IP
|
||||
if !netutil.ValidateIP(ip) {
|
||||
return ErrReservedAddr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -27,9 +27,9 @@ import (
|
|||
"github.com/robfig/cron/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
|
||||
|
@ -79,8 +79,8 @@ type Manager interface {
|
|||
type manager struct {
|
||||
db db.DB
|
||||
storage *kv.KVStore
|
||||
emojiWorker *worker.Worker[*ProcessingEmoji]
|
||||
mediaWorker *worker.Worker[*ProcessingMedia]
|
||||
emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
|
||||
mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
|
||||
stopCronJobs func() error
|
||||
}
|
||||
|
||||
|
@ -89,7 +89,7 @@ type manager struct {
|
|||
// A worker pool will also be initialized for the manager, to ensure that only
|
||||
// a limited number of media will be processed in parallel. The numbers of workers
|
||||
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
|
||||
// See internal/worker.New() documentation for further information.
|
||||
// See internal/concurrency.NewWorkerPool() documentation for further information.
|
||||
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
||||
m := &manager{
|
||||
db: database,
|
||||
|
@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
|||
}
|
||||
|
||||
// Prepare the media worker pool
|
||||
m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
|
||||
m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
|
||||
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
|
@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
|
|||
})
|
||||
|
||||
// Prepare the emoji worker pool
|
||||
m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
|
||||
m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
|
||||
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
|
|
78
internal/netutil/validate.go
Normal file
78
internal/netutil/validate.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
var (
|
||||
// IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
|
||||
IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
|
||||
|
||||
// IPvReserved contains IPv4 reserved IP prefixes.
|
||||
IPv4Reserved = [...]netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/8"), // Current network
|
||||
netip.MustParsePrefix("10.0.0.0/8"), // Private
|
||||
netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
|
||||
netip.MustParsePrefix("127.0.0.0/8"), // Loopback
|
||||
netip.MustParsePrefix("169.254.0.0/16"), // Link-local
|
||||
netip.MustParsePrefix("172.16.0.0/12"), // Private
|
||||
netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
|
||||
netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
|
||||
netip.MustParsePrefix("192.168.0.0/16"), // Private
|
||||
netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
|
||||
netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
|
||||
netip.MustParsePrefix("224.0.0.0/4"), // Multicast
|
||||
netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
|
||||
}
|
||||
)
|
||||
|
||||
// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
|
||||
func ValidateAddr(s string) bool {
|
||||
ipport, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return ValidateIP(ipport.Addr())
|
||||
}
|
||||
|
||||
// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
|
||||
func ValidateIP(ip netip.Addr) bool {
|
||||
switch {
|
||||
// IPv4: check if IPv4 in reserved nets
|
||||
case ip.Is4():
|
||||
for _, reserved := range IPv4Reserved {
|
||||
if reserved.Contains(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
// IPv6: check if in global unicast (public internet)
|
||||
case ip.Is6():
|
||||
return IPv6GlobalUnicast.Contains(ip)
|
||||
|
||||
// Assume malicious by default
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -23,6 +23,7 @@ import (
|
|||
"mime/multipart"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
|
@ -33,7 +34,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
)
|
||||
|
||||
|
@ -84,7 +84,7 @@ type Processor interface {
|
|||
type processor struct {
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
oauthServer oauth.Server
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
|
@ -94,7 +94,7 @@ type processor struct {
|
|||
}
|
||||
|
||||
// New returns a new account processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -35,7 +36,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
|
||||
suite.fromClientAPIChan <- msg
|
||||
return nil
|
||||
|
|
|
@ -23,13 +23,13 @@ import (
|
|||
"mime/multipart"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor wraps a bunch of functions for processing admin actions.
|
||||
|
@ -47,12 +47,12 @@ type Processor interface {
|
|||
type processor struct {
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// New returns a new admin processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
|
@ -33,7 +34,6 @@ import (
|
|||
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
|
|||
|
||||
return response, nil
|
||||
}
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
mockClient := testrig.NewMockHTTPClient(do)
|
||||
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -44,7 +45,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/timeline"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
|
||||
|
@ -237,8 +237,8 @@ type Processor interface {
|
|||
|
||||
// processor just implements the Processor interface
|
||||
type processor struct {
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
fedWorker *worker.Worker[messages.FromFederator]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
|
||||
federator federation.Federator
|
||||
tc typeutils.TypeConverter
|
||||
|
@ -271,8 +271,8 @@ func NewProcessor(
|
|||
storage *kv.KVStore,
|
||||
db db.DB,
|
||||
emailSender email.Sender,
|
||||
clientWorker *worker.Worker[messages.FromClientAPI],
|
||||
fedWorker *worker.Worker[messages.FromFederator],
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator],
|
||||
) Processor {
|
||||
parseMentionFunc := GetParseMentionFunc(db, federator)
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
|
@ -40,7 +41,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/timeline"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
|
|||
}, nil
|
||||
})
|
||||
|
||||
clientWorker := worker.New[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"context"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -29,7 +30,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// Processor wraps a bunch of functions for processing statuses.
|
||||
|
@ -74,12 +74,12 @@ type processor struct {
|
|||
db db.DB
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
parseMention gtsmodel.ParseMentionFunc
|
||||
}
|
||||
|
||||
// New returns a new status processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
return &processor{
|
||||
tc: tc,
|
||||
db: db,
|
||||
|
|
|
@ -21,6 +21,7 @@ package status_test
|
|||
import (
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -30,7 +31,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
|
|||
storage *kv.KVStore
|
||||
mediaManager media.Manager
|
||||
federator federation.Federator
|
||||
clientWorker *worker.Worker[messages.FromClientAPI]
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
|
||||
// standard suite models
|
||||
testTokens map[string]*gtsmodel.Token
|
||||
|
@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
|
|||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
|
||||
suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
|
||||
suite.storage = testrig.NewTestStorage()
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
|
|
|
@ -20,13 +20,17 @@ package transport
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/go-fed/httpsig"
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
"codeberg.org/gruf/go-cache/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
|
@ -37,109 +41,85 @@ import (
|
|||
|
||||
// Controller generates transports for use in making federation requests to other servers.
|
||||
type Controller interface {
|
||||
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
|
||||
// NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
|
||||
NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
|
||||
|
||||
// NewTransportForUsername searches for account with username, and returns result of .NewTransport().
|
||||
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
|
||||
}
|
||||
|
||||
type controller struct {
|
||||
db db.DB
|
||||
clock pub.Clock
|
||||
client pub.HttpClient
|
||||
appAgent string
|
||||
|
||||
// dereferenceFollowersShortcut is a shortcut to dereference followers of an
|
||||
// account on this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
|
||||
// dereferenceUserShortcut is a shortcut to dereference followers an account on
|
||||
// this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
}
|
||||
|
||||
func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
|
||||
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
followers, err := federatingDB.Followers(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(followers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
}
|
||||
|
||||
func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
|
||||
return func(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
user, err := federatingDB.Get(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
db db.DB
|
||||
fedDB federatingdb.DB
|
||||
clock pub.Clock
|
||||
client pub.HttpClient
|
||||
cache cache.Cache[string, *transport]
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// NewController returns an implementation of the Controller interface for creating new transports
|
||||
func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
|
||||
applicationName := viper.GetString(config.Keys.ApplicationName)
|
||||
host := viper.GetString(config.Keys.Host)
|
||||
appAgent := fmt.Sprintf("%s %s", applicationName, host)
|
||||
|
||||
return &controller{
|
||||
db: db,
|
||||
clock: clock,
|
||||
client: client,
|
||||
appAgent: appAgent,
|
||||
dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
|
||||
dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
|
||||
// Determine build information
|
||||
build, _ := debug.ReadBuildInfo()
|
||||
|
||||
c := &controller{
|
||||
db: db,
|
||||
fedDB: federatingDB,
|
||||
clock: clock,
|
||||
client: client,
|
||||
cache: cache.New[string, *transport](),
|
||||
userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
|
||||
}
|
||||
|
||||
// Transport cache has TTL=1hr freq=1m
|
||||
c.cache.SetTTL(time.Hour, false)
|
||||
if !c.cache.Start(time.Minute) {
|
||||
logrus.Panic("failed to start transport controller cache")
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
|
||||
func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
|
||||
prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
|
||||
digestAlgo := httpsig.DigestSha256
|
||||
getHeaders := []string{httpsig.RequestTarget, "host", "date"}
|
||||
postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
|
||||
func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
|
||||
// Generate public key string for cache key
|
||||
//
|
||||
// NOTE: it is safe to use the public key as the cache
|
||||
// key here as we are generating it ourselves from the
|
||||
// private key. If we were simply using a public key
|
||||
// provided as argument that would absolutely NOT be safe.
|
||||
pubStr := privkeyToPublicStr(privkey)
|
||||
|
||||
getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating get signer: %s", err)
|
||||
// First check for cached transport
|
||||
transp, ok := c.cache.Get(pubStr)
|
||||
if ok {
|
||||
return transp, nil
|
||||
}
|
||||
|
||||
postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating post signer: %s", err)
|
||||
// Create the transport
|
||||
transp = &transport{
|
||||
controller: c,
|
||||
pubKeyID: pubKeyID,
|
||||
privkey: privkey,
|
||||
}
|
||||
|
||||
sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
|
||||
// Cache this transport under pubkey
|
||||
if !c.cache.Put(pubStr, transp) {
|
||||
var cached *transport
|
||||
|
||||
return &transport{
|
||||
client: c.client,
|
||||
appAgent: c.appAgent,
|
||||
gofedAgent: "(go-fed/activity v1.0.0)",
|
||||
clock: c.clock,
|
||||
pubKeyID: pubKeyID,
|
||||
privkey: privkey,
|
||||
sigTransport: sigTransport,
|
||||
getSigner: getSigner,
|
||||
getSignerMu: &sync.Mutex{},
|
||||
dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
|
||||
dereferenceUserShortcut: c.dereferenceUserShortcut,
|
||||
}, nil
|
||||
cached, ok = c.cache.Get(pubStr)
|
||||
if !ok {
|
||||
// Some ridiculous race cond.
|
||||
c.cache.Set(pubStr, transp)
|
||||
} else {
|
||||
// Use already cached
|
||||
transp = cached
|
||||
}
|
||||
}
|
||||
|
||||
return transp, nil
|
||||
}
|
||||
|
||||
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
|
||||
|
@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
|
|||
}
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// dereferenceLocalFollowers is a shortcut to dereference followers of an
|
||||
// account on this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
followers, err := c.fedDB.Followers(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(followers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
|
||||
// dereferenceLocalUser is a shortcut to dereference followers an account on
|
||||
// this instance, without making any external api/http calls.
|
||||
//
|
||||
// It is passed to new transports, and should only be invoked when the iri.Host == this host.
|
||||
func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
user, err := c.fedDB.Get(ctx, iri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i, err := streams.Serialize(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return json.Marshal(i)
|
||||
}
|
||||
|
||||
// privkeyToPublicStr will create a string representation of RSA public key from private.
|
||||
func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
|
||||
b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
|
||||
return byteutil.B2S(b)
|
||||
}
|
||||
|
|
|
@ -19,13 +19,14 @@
|
|||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
|
||||
return t.sigTransport.Deliver(ctx, b, to)
|
||||
urlStr := to.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
|
||||
req.Header.Add("Accept-Charset", "utf-8")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", to.Host)
|
||||
|
||||
resp, err := t.POST(req, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if code := resp.StatusCode; code != http.StatusOK &&
|
||||
code != http.StatusCreated && code != http.StatusAccepted {
|
||||
return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -20,32 +20,55 @@ package transport
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
||||
)
|
||||
|
||||
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
|
||||
l := logrus.WithField("func", "Dereference")
|
||||
|
||||
// if the request is to us, we can shortcut for certain URIs rather than going through
|
||||
// the normal request flow, thereby saving time and energy
|
||||
if iri.Host == viper.GetString(config.Keys.Host) {
|
||||
if uris.IsFollowersPath(iri) {
|
||||
// the request is for followers of one of our accounts, which we can shortcut
|
||||
return t.dereferenceFollowersShortcut(ctx, iri)
|
||||
return t.controller.dereferenceLocalFollowers(ctx, iri)
|
||||
}
|
||||
|
||||
if uris.IsUserPath(iri) {
|
||||
// the request is for one of our accounts, which we can shortcut
|
||||
return t.dereferenceUserShortcut(ctx, iri)
|
||||
return t.controller.dereferenceLocalUser(ctx, iri)
|
||||
}
|
||||
}
|
||||
|
||||
// the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
return t.sigTransport.Dereference(ctx, iri)
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
// Prepare new HTTP request to endpoint
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
|
||||
req.Header.Add("Accept-Charset", "utf-8")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(rsp.Body)
|
||||
}
|
||||
|
|
|
@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
|
|||
}
|
||||
|
||||
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
|
||||
l := logrus.WithField("func", "dereferenceByAPIV1Instance")
|
||||
|
||||
cleanIRI := &url.URL{
|
||||
Scheme: iri.Scheme,
|
||||
Host: iri.Host,
|
||||
Path: "api/v1/instance",
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", cleanIRI.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
// Build IRI just once
|
||||
iriStr := cleanIRI.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("response bytes was len 0")
|
||||
}
|
||||
|
||||
|
@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
|
|||
}
|
||||
|
||||
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
|
||||
l := logrus.WithField("func", "callNodeInfoWellKnown")
|
||||
|
||||
cleanIRI := &url.URL{
|
||||
Scheme: iri.Scheme,
|
||||
Host: iri.Host,
|
||||
Path: ".well-known/nodeinfo",
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", cleanIRI.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Build IRI just once
|
||||
iriStr := cleanIRI.String()
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", cleanIRI.Host)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
|
||||
return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
|
||||
}
|
||||
|
||||
|
@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
|
|||
}
|
||||
|
||||
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
|
||||
l := logrus.WithField("func", "callNodeInfo")
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
|
||||
resp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
} else if len(b) == 0 {
|
||||
return nil, errors.New("callNodeInfo: response bytes was len 0")
|
||||
}
|
||||
|
||||
|
|
|
@ -24,34 +24,31 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
|
||||
l := logrus.WithField("func", "DereferenceMedia")
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
// Build IRI just once
|
||||
iriStr := iri.String()
|
||||
|
||||
// Prepare HTTP request to this media's IRI
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", iri.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
return resp.Body, int(resp.ContentLength), nil
|
||||
|
||||
return rsp.Body, int(rsp.ContentLength), nil
|
||||
}
|
||||
|
|
|
@ -23,46 +23,36 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
|
||||
l := logrus.WithField("func", "Finger")
|
||||
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
|
||||
l.Debugf("performing GET to %s", urlString)
|
||||
// Prepare URL string
|
||||
urlStr := "https://" +
|
||||
targetDomain +
|
||||
"/.well-known/webfinger?resource=acct:" +
|
||||
targetUsername + "@" + targetDomain
|
||||
|
||||
iri, err := url.Parse(urlString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
|
||||
}
|
||||
|
||||
l.Debugf("performing GET to %s", iri.String())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
|
||||
// Generate new GET request from URL string
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Accept", "application/jrd+json")
|
||||
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
|
||||
req.Header.Set("Host", iri.Host)
|
||||
t.getSignerMu.Lock()
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
|
||||
t.getSignerMu.Unlock()
|
||||
req.Header.Add("User-Agent", t.controller.userAgent)
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
|
||||
// Perform the HTTP request
|
||||
rsp, err := t.GET(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// Check for an expected status code
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
|
||||
}
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
|
||||
return ioutil.ReadAll(rsp.Body)
|
||||
}
|
||||
|
|
43
internal/transport/signing.go
Normal file
43
internal/transport/signing.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"github.com/go-fed/httpsig"
|
||||
)
|
||||
|
||||
var (
|
||||
// http signer preferences
|
||||
prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
|
||||
digestAlgo = httpsig.DigestSha256
|
||||
getHeaders = []string{httpsig.RequestTarget, "host", "date"}
|
||||
postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
|
||||
)
|
||||
|
||||
// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
|
||||
func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
|
||||
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
|
||||
return sig, err
|
||||
}
|
||||
|
||||
// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
|
||||
func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
|
||||
sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
|
||||
return sig, err
|
||||
}
|
|
@ -21,11 +21,18 @@ package transport
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
errorsv2 "codeberg.org/gruf/go-errors/v2"
|
||||
"github.com/go-fed/httpsig"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
@ -43,28 +50,148 @@ type Transport interface {
|
|||
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
|
||||
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
|
||||
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
|
||||
// SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
|
||||
SigTransport() pub.Transport
|
||||
}
|
||||
|
||||
// transport implements the Transport interface
|
||||
type transport struct {
|
||||
client pub.HttpClient
|
||||
appAgent string
|
||||
gofedAgent string
|
||||
clock pub.Clock
|
||||
pubKeyID string
|
||||
privkey crypto.PrivateKey
|
||||
sigTransport *pub.HttpSigTransport
|
||||
getSigner httpsig.Signer
|
||||
getSignerMu *sync.Mutex
|
||||
controller *controller
|
||||
pubKeyID string
|
||||
privkey crypto.PrivateKey
|
||||
|
||||
// shortcuts for dereferencing things that exist on our instance without making an http call to ourself
|
||||
|
||||
dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
|
||||
signerExp time.Time
|
||||
getSigner httpsig.Signer
|
||||
postSigner httpsig.Signer
|
||||
signerMu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *transport) SigTransport() pub.Transport {
|
||||
return t.sigTransport
|
||||
// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
|
||||
func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
|
||||
if r.Method != http.MethodGet {
|
||||
return nil, errors.New("must be GET request")
|
||||
}
|
||||
return t.do(r, func(r *http.Request) error {
|
||||
return t.signGET(r)
|
||||
}, retryOn...)
|
||||
}
|
||||
|
||||
// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
|
||||
func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
|
||||
if r.Method != http.MethodPost {
|
||||
return nil, errors.New("must be POST request")
|
||||
}
|
||||
return t.do(r, func(r *http.Request) error {
|
||||
return t.signPOST(r, body)
|
||||
}, retryOn...)
|
||||
}
|
||||
|
||||
func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
|
||||
const maxRetries = 5
|
||||
backoff := time.Second * 2
|
||||
|
||||
// Start a log entry for this request
|
||||
l := logrus.WithFields(logrus.Fields{
|
||||
"pubKeyID": t.pubKeyID,
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
})
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
// Reset signing header fields
|
||||
now := t.controller.clock.Now().UTC()
|
||||
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
r.Header.Del("Signature")
|
||||
r.Header.Del("Digest")
|
||||
|
||||
// Perform request signing
|
||||
if err := signer(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.Infof("performing request")
|
||||
|
||||
// Attempt to perform request
|
||||
rsp, err := t.controller.client.Do(r)
|
||||
if err == nil { //nolint shutup linter
|
||||
// TooManyRequest means we need to slow
|
||||
// down and retry our request. Codes over
|
||||
// 500 generally indicate temp. outages.
|
||||
if code := rsp.StatusCode; code < 500 &&
|
||||
code != http.StatusTooManyRequests &&
|
||||
!containsInt(retryOn, rsp.StatusCode) {
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
// Generate error from status code for logging
|
||||
err = errors.New(`http response "` + rsp.Status + `"`)
|
||||
} else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
|
||||
// Return early if context has cancelled
|
||||
return nil, err
|
||||
} else if strings.Contains(err.Error(), "stopped after 10 redirects") {
|
||||
// Don't bother if net/http returned after too many redirects
|
||||
return nil, err
|
||||
} else if errors.As(err, &x509.UnknownAuthorityError{}) {
|
||||
// Unknown authority errors we do NOT recover from
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
|
||||
|
||||
select {
|
||||
// Request ctx cancelled
|
||||
case <-r.Context().Done():
|
||||
return nil, r.Context().Err()
|
||||
|
||||
// Backoff for some time
|
||||
case <-time.After(backoff):
|
||||
backoff *= 2
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("transport reached max retries")
|
||||
}
|
||||
|
||||
// signGET will safely sign an HTTP GET request.
|
||||
func (t *transport) signGET(r *http.Request) (err error) {
|
||||
t.safesign(func() {
|
||||
err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// signPOST will safely sign an HTTP POST request for given body.
|
||||
func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
|
||||
t.safesign(func() {
|
||||
err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// safesign will perform sign function within mutex protection,
|
||||
// and ensured that httpsig.Signers are up-to-date.
|
||||
func (t *transport) safesign(sign func()) {
|
||||
// Perform within mu safety
|
||||
t.signerMu.Lock()
|
||||
defer t.signerMu.Unlock()
|
||||
|
||||
if now := time.Now(); now.After(t.signerExp) {
|
||||
const expiry = 120
|
||||
|
||||
// Signers have expired and require renewal
|
||||
t.getSigner, _ = NewGETSigner(expiry)
|
||||
t.postSigner, _ = NewPOSTSigner(expiry)
|
||||
t.signerExp = now.Add(time.Second * expiry)
|
||||
}
|
||||
|
||||
// Perform signing
|
||||
sign()
|
||||
}
|
||||
|
||||
// containsInt checks if slice contains check.
|
||||
func containsInt(slice []int, check int) bool {
|
||||
for _, i := range slice {
|
||||
if i == check {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package testrig
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// NewTestFederatingDB returns a federating DB with the underlying db
|
||||
func NewTestFederatingDB(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) federatingdb.DB {
|
||||
func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB {
|
||||
return federatingdb.New(db, fedWorker)
|
||||
}
|
||||
|
|
|
@ -20,15 +20,15 @@ package testrig
|
|||
|
||||
import (
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// NewTestFederator returns a federator with the given database and (mock!!) transport controller.
|
||||
func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *worker.Worker[messages.FromFederator]) federation.Federator {
|
||||
func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator {
|
||||
return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager)
|
||||
}
|
||||
|
|
|
@ -20,16 +20,16 @@ package testrig
|
|||
|
||||
import (
|
||||
"codeberg.org/gruf/go-store/kv"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// NewTestProcessor returns a Processor suitable for testing purposes
|
||||
func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI], fedWorker *worker.Worker[messages.FromFederator]) processing.Processor {
|
||||
func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) processing.Processor {
|
||||
return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker)
|
||||
}
|
||||
|
|
|
@ -20,8 +20,6 @@ package testrig
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
|
@ -29,7 +27,6 @@ import (
|
|||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -42,8 +39,7 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
)
|
||||
|
||||
// NewTestTokens returns a map of tokens keyed according to which account the token belongs to.
|
||||
|
@ -1855,86 +1851,71 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
|
|||
}
|
||||
}
|
||||
|
||||
// GetSignatureForActivity does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
|
||||
// the HTTP Signature for the given activity, public key ID, private key, and destination.
|
||||
func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
|
||||
// create a client that basically just pulls the signature out of the request and sets it
|
||||
client := &mockHTTPClient{
|
||||
do: func(req *http.Request) (*http.Response, error) {
|
||||
signatureHeader = req.Header.Get("Signature")
|
||||
digestHeader = req.Header.Get("Digest")
|
||||
dateHeader = req.Header.Get("Date")
|
||||
r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: r,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Create temporary federator worker for transport controller
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
_ = fedWorker.Start()
|
||||
defer func() { _ = fedWorker.Stop() }()
|
||||
|
||||
// use the client to create a new transport
|
||||
c := NewTestTransportController(client, NewTestDB(), fedWorker)
|
||||
tp, err := c.NewTransport(pubKeyID, privkey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values.
|
||||
func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
|
||||
// convert the activity into json bytes
|
||||
m, err := activity.Serialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bytes, err := json.Marshal(m)
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// trigger the delivery function for the underlying signature transport, which will trigger the 'do' function of the recorder above
|
||||
if err := tp.SigTransport().Deliver(context.Background(), bytes, destination); err != nil {
|
||||
// Prepare HTTP request signer
|
||||
sig, err := transport.NewPOSTSigner(120)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Prepare a mock request ready for signing
|
||||
r, err := http.NewRequest("POST", destination.String(), bytes.NewReader(b))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r.Header.Set("Host", destination.Host)
|
||||
r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
|
||||
// Sign this new HTTP request
|
||||
if err := sig.SignRequest(privkey, pubKeyID, r, b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Load signed data from request
|
||||
signatureHeader = r.Header.Get("Signature")
|
||||
digestHeader = r.Header.Get("Digest")
|
||||
dateHeader = r.Header.Get("Date")
|
||||
|
||||
// headers should now be populated
|
||||
return
|
||||
}
|
||||
|
||||
// GetSignatureForDereference does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
|
||||
// the HTTP Signature for the given derefence GET request using public key ID, private key, and destination.
|
||||
func GetSignatureForDereference(pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
|
||||
// create a client that basically just pulls the signature out of the request and sets it
|
||||
client := &mockHTTPClient{
|
||||
do: func(req *http.Request) (*http.Response, error) {
|
||||
signatureHeader = req.Header.Get("Signature")
|
||||
dateHeader = req.Header.Get("Date")
|
||||
r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: r,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Create temporary federator worker for transport controller
|
||||
fedWorker := worker.New[messages.FromFederator](-1, -1)
|
||||
_ = fedWorker.Start()
|
||||
defer func() { _ = fedWorker.Stop() }()
|
||||
|
||||
// use the client to create a new transport
|
||||
c := NewTestTransportController(client, NewTestDB(), fedWorker)
|
||||
tp, err := c.NewTransport(pubKeyID, privkey)
|
||||
// GetSignatureForDereference prepares a mock HTTP request as if it were going to dereference destination signed for privkey and pubKeyID, signs the request and returns the header values.
|
||||
func GetSignatureForDereference(pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
|
||||
// Prepare HTTP request signer
|
||||
sig, err := transport.NewGETSigner(120)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// trigger the dereference function for the underlying signature transport, which will trigger the 'do' function of the recorder above
|
||||
if _, err := tp.SigTransport().Dereference(context.Background(), destination); err != nil {
|
||||
// Prepare a mock request ready for signing
|
||||
r, err := http.NewRequest("GET", destination.String(), nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r.Header.Set("Host", destination.Host)
|
||||
r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
|
||||
|
||||
// Sign this new HTTP request
|
||||
if err := sig.SignRequest(privkey, pubKeyID, r, nil); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Load signed data from request
|
||||
signatureHeader = r.Header.Get("Signature")
|
||||
digestHeader = r.Header.Get("Digest")
|
||||
dateHeader = r.Header.Get("Date")
|
||||
|
||||
// headers should now be populated
|
||||
return
|
||||
|
|
|
@ -24,11 +24,11 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/worker"
|
||||
)
|
||||
|
||||
// NewTestTransportController returns a test transport controller with the given http client.
|
||||
|
@ -40,7 +40,7 @@ import (
|
|||
// Unlike the other test interfaces provided in this package, you'll probably want to call this function
|
||||
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
|
||||
// basis.
|
||||
func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *worker.Worker[messages.FromFederator]) transport.Controller {
|
||||
func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller {
|
||||
return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client)
|
||||
}
|
||||
|
||||
|
|
19
vendor/codeberg.org/gruf/go-byteutil/bytes.go
generated
vendored
19
vendor/codeberg.org/gruf/go-byteutil/bytes.go
generated
vendored
|
@ -16,11 +16,30 @@ func Copy(b []byte) []byte {
|
|||
}
|
||||
|
||||
// B2S returns a string representation of []byte without allocation.
|
||||
//
|
||||
// According to the Go spec strings are immutable and byte slices are not. The way this gets implemented is strings under the hood are:
|
||||
// type StringHeader struct {
|
||||
// Data uintptr
|
||||
// Len int
|
||||
// }
|
||||
//
|
||||
// while slices are:
|
||||
// type SliceHeader struct {
|
||||
// Data uintptr
|
||||
// Len int
|
||||
// Cap int
|
||||
// }
|
||||
// because being mutable, you can change the data, length etc, but the string has to promise to be read-only to all who get copies of it.
|
||||
//
|
||||
// So in practice when you do a conversion of `string(byteSlice)` it actually performs an allocation because it has to copy the contents of the byte slice into a safe read-only state.
|
||||
//
|
||||
// Being that the shared fields are in the same struct indices (no different offsets), means that if you have a byte slice you can "forcibly" cast it to a string. Which in a lot of situations can be risky, because then it means you have a string that is NOT immutable, as if someone changes the data in the originating byte slice then the string will reflect that change! Now while this does seem hacky, and it _kind_ of is, it is something that you see performed in the standard library. If you look at the definition for `strings.Builder{}.String()` you'll see this :)
|
||||
func B2S(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
|
||||
// S2B returns a []byte representation of string without allocation (minus slice header).
|
||||
// See B2S() code comment, and this function's implementation for a better understanding.
|
||||
func S2B(s string) []byte {
|
||||
var b []byte
|
||||
|
||||
|
|
9
vendor/codeberg.org/gruf/go-cache/v2/LICENSE
generated
vendored
Normal file
9
vendor/codeberg.org/gruf/go-cache/v2/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021 gruf
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
3
vendor/codeberg.org/gruf/go-cache/v2/README.md
generated
vendored
Normal file
3
vendor/codeberg.org/gruf/go-cache/v2/README.md
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
# go-cache
|
||||
|
||||
A TTL cache designed to be used as a base for your own customizations, or used straight out of the box
|
67
vendor/codeberg.org/gruf/go-cache/v2/cache.go
generated
vendored
Normal file
67
vendor/codeberg.org/gruf/go-cache/v2/cache.go
generated
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
package cache
|
||||
|
||||
import "time"
|
||||
|
||||
// Cache represents a TTL cache with customizable callbacks, it
|
||||
// exists here to abstract away the "unsafe" methods in the case that
|
||||
// you do not want your own implementation atop TTLCache{}.
|
||||
type Cache[Key comparable, Value any] interface {
|
||||
// Start will start the cache background eviction routine with given sweep frequency.
|
||||
// If already running or a freq <= 0 provided, this is a no-op. This will block until
|
||||
// the eviction routine has started
|
||||
Start(freq time.Duration) bool
|
||||
|
||||
// Stop will stop cache background eviction routine. If not running this is a no-op. This
|
||||
// will block until the eviction routine has stopped
|
||||
Stop() bool
|
||||
|
||||
// SetEvictionCallback sets the eviction callback to the provided hook
|
||||
SetEvictionCallback(hook Hook[Key, Value])
|
||||
|
||||
// SetInvalidateCallback sets the invalidate callback to the provided hook
|
||||
SetInvalidateCallback(hook Hook[Key, Value])
|
||||
|
||||
// SetTTL sets the cache item TTL. Update can be specified to force updates of existing items in
|
||||
// the cache, this will simply add the change in TTL to their current expiry time
|
||||
SetTTL(ttl time.Duration, update bool)
|
||||
|
||||
// Get fetches the value with key from the cache, extending its TTL
|
||||
Get(key Key) (value Value, ok bool)
|
||||
|
||||
// Put attempts to place the value at key in the cache, doing nothing if
|
||||
// a value with this key already exists. Returned bool is success state
|
||||
Put(key Key, value Value) bool
|
||||
|
||||
// Set places the value at key in the cache. This will overwrite any
|
||||
// existing value, and call the update callback so. Existing values
|
||||
// will have their TTL extended upon update
|
||||
Set(key Key, value Value)
|
||||
|
||||
// CAS will attempt to perform a CAS operation on 'key', using provided
|
||||
// comparison and swap values. Returned bool is success.
|
||||
CAS(key Key, cmp, swp Value) bool
|
||||
|
||||
// Swap will attempt to perform a swap on 'key', replacing the value there
|
||||
// and returning the existing value. If no value exists for key, this will
|
||||
// set the value and return the zero value for V.
|
||||
Swap(key Key, swp Value) Value
|
||||
|
||||
// Has checks the cache for a value with key, this will not update TTL
|
||||
Has(key Key) bool
|
||||
|
||||
// Invalidate deletes a value from the cache, calling the invalidate callback
|
||||
Invalidate(key Key) bool
|
||||
|
||||
// Clear empties the cache, calling the invalidate callback
|
||||
Clear()
|
||||
|
||||
// Size returns the current size of the cache
|
||||
Size() int
|
||||
}
|
||||
|
||||
// New returns a new initialized Cache.
|
||||
func New[K comparable, V any]() Cache[K, V] {
|
||||
c := TTLCache[K, V]{}
|
||||
c.Init()
|
||||
return &c
|
||||
}
|
23
vendor/codeberg.org/gruf/go-cache/v2/compare.go
generated
vendored
Normal file
23
vendor/codeberg.org/gruf/go-cache/v2/compare.go
generated
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Comparable interface {
|
||||
Equal(any) bool
|
||||
}
|
||||
|
||||
// Compare returns whether 2 values are equal using the Comparable
|
||||
// interface, or failing that falls back to use reflect.DeepEqual().
|
||||
func Compare(i1, i2 any) bool {
|
||||
c1, ok1 := i1.(Comparable)
|
||||
if ok1 {
|
||||
return c1.Equal(i2)
|
||||
}
|
||||
c2, ok2 := i2.(Comparable)
|
||||
if ok2 {
|
||||
return c2.Equal(i1)
|
||||
}
|
||||
return reflect.DeepEqual(i1, i2)
|
||||
}
|
6
vendor/codeberg.org/gruf/go-cache/v2/hook.go
generated
vendored
Normal file
6
vendor/codeberg.org/gruf/go-cache/v2/hook.go
generated
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
package cache
|
||||
|
||||
// Hook defines a function hook that can be supplied as a callback.
|
||||
type Hook[Key comparable, Value any] func(key Key, value Value)
|
||||
|
||||
func emptyHook[K comparable, V any](K, V) {}
|
214
vendor/codeberg.org/gruf/go-cache/v2/lookup.go
generated
vendored
Normal file
214
vendor/codeberg.org/gruf/go-cache/v2/lookup.go
generated
vendored
Normal file
|
@ -0,0 +1,214 @@
|
|||
package cache
|
||||
|
||||
// LookupCfg is the LookupCache configuration.
|
||||
type LookupCfg[OGKey, AltKey comparable, Value any] struct {
|
||||
// RegisterLookups is called on init to register lookups
|
||||
// within LookupCache's internal LookupMap
|
||||
RegisterLookups func(*LookupMap[OGKey, AltKey])
|
||||
|
||||
// AddLookups is called on each addition to the cache, to
|
||||
// set any required additional key lookups for supplied item
|
||||
AddLookups func(*LookupMap[OGKey, AltKey], Value)
|
||||
|
||||
// DeleteLookups is called on each eviction/invalidation of
|
||||
// an item in the cache, to remove any unused key lookups
|
||||
DeleteLookups func(*LookupMap[OGKey, AltKey], Value)
|
||||
}
|
||||
|
||||
// LookupCache is a cache built on-top of TTLCache, providing multi-key
|
||||
// lookups for items in the cache by means of additional lookup maps. These
|
||||
// maps simply store additional keys => original key, with hook-ins to automatically
|
||||
// call user supplied functions on adding an item, or on updating/deleting an
|
||||
// item to keep the LookupMap up-to-date.
|
||||
type LookupCache[OGKey, AltKey comparable, Value any] interface {
|
||||
Cache[OGKey, Value]
|
||||
|
||||
// GetBy fetches a cached value by supplied lookup identifier and key
|
||||
GetBy(lookup string, key AltKey) (value Value, ok bool)
|
||||
|
||||
// CASBy will attempt to perform a CAS operation on supplied lookup identifier and key
|
||||
CASBy(lookup string, key AltKey, cmp, swp Value) bool
|
||||
|
||||
// SwapBy will attempt to perform a swap operation on supplied lookup identifier and key
|
||||
SwapBy(lookup string, key AltKey, swp Value) Value
|
||||
|
||||
// HasBy checks if a value is cached under supplied lookup identifier and key
|
||||
HasBy(lookup string, key AltKey) bool
|
||||
|
||||
// InvalidateBy invalidates a value by supplied lookup identifier and key
|
||||
InvalidateBy(lookup string, key AltKey) bool
|
||||
}
|
||||
|
||||
type lookupTTLCache[OK, AK comparable, V any] struct {
|
||||
config LookupCfg[OK, AK, V]
|
||||
lookup LookupMap[OK, AK]
|
||||
TTLCache[OK, V]
|
||||
}
|
||||
|
||||
// NewLookup returns a new initialized LookupCache.
|
||||
func NewLookup[OK, AK comparable, V any](cfg LookupCfg[OK, AK, V]) LookupCache[OK, AK, V] {
|
||||
switch {
|
||||
case cfg.RegisterLookups == nil:
|
||||
panic("cache: nil lookups register function")
|
||||
case cfg.AddLookups == nil:
|
||||
panic("cache: nil lookups add function")
|
||||
case cfg.DeleteLookups == nil:
|
||||
panic("cache: nil delete lookups function")
|
||||
}
|
||||
c := lookupTTLCache[OK, AK, V]{config: cfg}
|
||||
c.TTLCache.Init()
|
||||
c.lookup.lookup = make(map[string]map[AK]OK)
|
||||
c.config.RegisterLookups(&c.lookup)
|
||||
c.SetEvictionCallback(nil)
|
||||
c.SetInvalidateCallback(nil)
|
||||
c.lookup.initd = true
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) SetEvictionCallback(hook Hook[OK, V]) {
|
||||
if hook == nil {
|
||||
hook = emptyHook[OK, V]
|
||||
}
|
||||
c.TTLCache.SetEvictionCallback(func(key OK, value V) {
|
||||
hook(key, value)
|
||||
c.config.DeleteLookups(&c.lookup, value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) SetInvalidateCallback(hook Hook[OK, V]) {
|
||||
if hook == nil {
|
||||
hook = emptyHook[OK, V]
|
||||
}
|
||||
c.TTLCache.SetInvalidateCallback(func(key OK, value V) {
|
||||
hook(key, value)
|
||||
c.config.DeleteLookups(&c.lookup, value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) GetBy(lookup string, key AK) (V, bool) {
|
||||
c.Lock()
|
||||
origKey, ok := c.lookup.Get(lookup, key)
|
||||
if !ok {
|
||||
c.Unlock()
|
||||
var value V
|
||||
return value, false
|
||||
}
|
||||
v, ok := c.GetUnsafe(origKey)
|
||||
c.Unlock()
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) Put(key OK, value V) bool {
|
||||
c.Lock()
|
||||
put := c.PutUnsafe(key, value)
|
||||
if put {
|
||||
c.config.AddLookups(&c.lookup, value)
|
||||
}
|
||||
c.Unlock()
|
||||
return put
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) Set(key OK, value V) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.SetUnsafe(key, value)
|
||||
c.config.AddLookups(&c.lookup, value)
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) CASBy(lookup string, key AK, cmp, swp V) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
origKey, ok := c.lookup.Get(lookup, key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return c.CASUnsafe(origKey, cmp, swp)
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) SwapBy(lookup string, key AK, swp V) V {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
origKey, ok := c.lookup.Get(lookup, key)
|
||||
if !ok {
|
||||
var value V
|
||||
return value
|
||||
}
|
||||
return c.SwapUnsafe(origKey, swp)
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) HasBy(lookup string, key AK) bool {
|
||||
c.Lock()
|
||||
has := c.lookup.Has(lookup, key)
|
||||
c.Unlock()
|
||||
return has
|
||||
}
|
||||
|
||||
func (c *lookupTTLCache[OK, AK, V]) InvalidateBy(lookup string, key AK) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
origKey, ok := c.lookup.Get(lookup, key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
c.InvalidateUnsafe(origKey)
|
||||
return true
|
||||
}
|
||||
|
||||
// LookupMap is a structure that provides lookups for
|
||||
// keys to primary keys under supplied lookup identifiers.
|
||||
// This is essentially a wrapper around map[string](map[K1]K2).
|
||||
type LookupMap[OK comparable, AK comparable] struct {
|
||||
initd bool
|
||||
lookup map[string](map[AK]OK)
|
||||
}
|
||||
|
||||
// RegisterLookup registers a lookup identifier in the LookupMap,
|
||||
// note this can only be doing during the cfg.RegisterLookups() hook.
|
||||
func (l *LookupMap[OK, AK]) RegisterLookup(id string) {
|
||||
if l.initd {
|
||||
panic("cache: cannot register lookup after initialization")
|
||||
} else if _, ok := l.lookup[id]; ok {
|
||||
panic("cache: lookup mapping already exists for identifier")
|
||||
}
|
||||
l.lookup[id] = make(map[AK]OK, 100)
|
||||
}
|
||||
|
||||
// Get fetches an entry's primary key for lookup identifier and key.
|
||||
func (l *LookupMap[OK, AK]) Get(id string, key AK) (OK, bool) {
|
||||
keys, ok := l.lookup[id]
|
||||
if !ok {
|
||||
var key OK
|
||||
return key, false
|
||||
}
|
||||
origKey, ok := keys[key]
|
||||
return origKey, ok
|
||||
}
|
||||
|
||||
// Set adds a lookup to the LookupMap under supplied lookup identifier,
|
||||
// linking supplied key to the supplied primary (original) key.
|
||||
func (l *LookupMap[OK, AK]) Set(id string, key AK, origKey OK) {
|
||||
keys, ok := l.lookup[id]
|
||||
if !ok {
|
||||
panic("cache: invalid lookup identifier")
|
||||
}
|
||||
keys[key] = origKey
|
||||
}
|
||||
|
||||
// Has checks if there exists a lookup for supplied identifier and key.
|
||||
func (l *LookupMap[OK, AK]) Has(id string, key AK) bool {
|
||||
keys, ok := l.lookup[id]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = keys[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Delete removes a lookup from LookupMap with supplied identifier and key.
|
||||
func (l *LookupMap[OK, AK]) Delete(id string, key AK) {
|
||||
keys, ok := l.lookup[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(keys, key)
|
||||
}
|
333
vendor/codeberg.org/gruf/go-cache/v2/ttl.go
generated
vendored
Normal file
333
vendor/codeberg.org/gruf/go-cache/v2/ttl.go
generated
vendored
Normal file
|
@ -0,0 +1,333 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-runners"
|
||||
)
|
||||
|
||||
// TTLCache is the underlying Cache implementation, providing both the base
|
||||
// Cache interface and access to "unsafe" methods so that you may build your
|
||||
// customized caches ontop of this structure.
|
||||
type TTLCache[Key comparable, Value any] struct {
|
||||
cache map[Key](*entry[Value])
|
||||
evict Hook[Key, Value] // the evict hook is called when an item is evicted from the cache, includes manual delete
|
||||
invalid Hook[Key, Value] // the invalidate hook is called when an item's data in the cache is invalidated
|
||||
ttl time.Duration // ttl is the item TTL
|
||||
svc runners.Service // svc manages running of the cache eviction routine
|
||||
mu sync.Mutex // mu protects TTLCache for concurrent access
|
||||
}
|
||||
|
||||
// Init performs Cache initialization, this MUST be called.
|
||||
func (c *TTLCache[K, V]) Init() {
|
||||
c.cache = make(map[K](*entry[V]), 100)
|
||||
c.evict = emptyHook[K, V]
|
||||
c.invalid = emptyHook[K, V]
|
||||
c.ttl = time.Minute * 5
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Start(freq time.Duration) bool {
|
||||
// Nothing to start
|
||||
if freq <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Track state of starting
|
||||
done := make(chan struct{})
|
||||
started := false
|
||||
|
||||
go func() {
|
||||
ran := c.svc.Run(func(ctx context.Context) {
|
||||
// Successfully started
|
||||
started = true
|
||||
close(done)
|
||||
|
||||
// start routine
|
||||
c.run(ctx, freq)
|
||||
})
|
||||
|
||||
// failed to start
|
||||
if !ran {
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
return started
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Stop() bool {
|
||||
return c.svc.Stop()
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) run(ctx context.Context, freq time.Duration) {
|
||||
t := time.NewTimer(freq)
|
||||
for {
|
||||
select {
|
||||
// we got stopped
|
||||
case <-ctx.Done():
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
}
|
||||
return
|
||||
|
||||
// next tick
|
||||
case <-t.C:
|
||||
c.sweep()
|
||||
t.Reset(freq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sweep attempts to evict expired items (with callback!) from cache.
|
||||
func (c *TTLCache[K, V]) sweep() {
|
||||
// Lock and defer unlock (in case of hook panic)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Fetch current time for TTL check
|
||||
now := time.Now()
|
||||
|
||||
// Sweep the cache for old items!
|
||||
for key, item := range c.cache {
|
||||
if now.After(item.expiry) {
|
||||
c.evict(key, item.value)
|
||||
delete(c.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lock locks the cache mutex.
|
||||
func (c *TTLCache[K, V]) Lock() {
|
||||
c.mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock unlocks the cache mutex.
|
||||
func (c *TTLCache[K, V]) Unlock() {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) SetEvictionCallback(hook Hook[K, V]) {
|
||||
// Ensure non-nil hook
|
||||
if hook == nil {
|
||||
hook = emptyHook[K, V]
|
||||
}
|
||||
|
||||
// Safely set evict hook
|
||||
c.Lock()
|
||||
c.evict = hook
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) SetInvalidateCallback(hook Hook[K, V]) {
|
||||
// Ensure non-nil hook
|
||||
if hook == nil {
|
||||
hook = emptyHook[K, V]
|
||||
}
|
||||
|
||||
// Safely set invalidate hook
|
||||
c.Lock()
|
||||
c.invalid = hook
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) SetTTL(ttl time.Duration, update bool) {
|
||||
// Safely update TTL
|
||||
c.Lock()
|
||||
diff := ttl - c.ttl
|
||||
c.ttl = ttl
|
||||
|
||||
if update {
|
||||
// Update existing cache entries
|
||||
for _, entry := range c.cache {
|
||||
entry.expiry.Add(diff)
|
||||
}
|
||||
}
|
||||
|
||||
// We're done
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
|
||||
c.Lock()
|
||||
value, ok := c.GetUnsafe(key)
|
||||
c.Unlock()
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// GetUnsafe is the mutex-unprotected logic for Cache.Get().
|
||||
func (c *TTLCache[K, V]) GetUnsafe(key K) (V, bool) {
|
||||
item, ok := c.cache[key]
|
||||
if !ok {
|
||||
var value V
|
||||
return value, false
|
||||
}
|
||||
item.expiry = time.Now().Add(c.ttl)
|
||||
return item.value, true
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Put(key K, value V) bool {
|
||||
c.Lock()
|
||||
success := c.PutUnsafe(key, value)
|
||||
c.Unlock()
|
||||
return success
|
||||
}
|
||||
|
||||
// PutUnsafe is the mutex-unprotected logic for Cache.Put().
|
||||
func (c *TTLCache[K, V]) PutUnsafe(key K, value V) bool {
|
||||
// If already cached, return
|
||||
if _, ok := c.cache[key]; ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Create new cached item
|
||||
c.cache[key] = &entry[V]{
|
||||
value: value,
|
||||
expiry: time.Now().Add(c.ttl),
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Set(key K, value V) {
|
||||
c.Lock()
|
||||
defer c.Unlock() // defer in case of hook panic
|
||||
c.SetUnsafe(key, value)
|
||||
}
|
||||
|
||||
// SetUnsafe is the mutex-unprotected logic for Cache.Set(), it calls externally-set functions.
|
||||
func (c *TTLCache[K, V]) SetUnsafe(key K, value V) {
|
||||
item, ok := c.cache[key]
|
||||
if ok {
|
||||
// call invalidate hook
|
||||
c.invalid(key, item.value)
|
||||
} else {
|
||||
// alloc new item
|
||||
item = &entry[V]{}
|
||||
c.cache[key] = item
|
||||
}
|
||||
|
||||
// Update the item + expiry
|
||||
item.value = value
|
||||
item.expiry = time.Now().Add(c.ttl)
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) CAS(key K, cmp V, swp V) bool {
|
||||
c.Lock()
|
||||
ok := c.CASUnsafe(key, cmp, swp)
|
||||
c.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// CASUnsafe is the mutex-unprotected logic for Cache.CAS().
|
||||
func (c *TTLCache[K, V]) CASUnsafe(key K, cmp V, swp V) bool {
|
||||
// Check for item
|
||||
item, ok := c.cache[key]
|
||||
if !ok || !Compare(item.value, cmp) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Invalidate item
|
||||
c.invalid(key, item.value)
|
||||
|
||||
// Update item + expiry
|
||||
item.value = swp
|
||||
item.expiry = time.Now().Add(c.ttl)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Swap(key K, swp V) V {
|
||||
c.Lock()
|
||||
old := c.SwapUnsafe(key, swp)
|
||||
c.Unlock()
|
||||
return old
|
||||
}
|
||||
|
||||
// SwapUnsafe is the mutex-unprotected logic for Cache.Swap().
|
||||
func (c *TTLCache[K, V]) SwapUnsafe(key K, swp V) V {
|
||||
// Check for item
|
||||
item, ok := c.cache[key]
|
||||
if !ok {
|
||||
var value V
|
||||
return value
|
||||
}
|
||||
|
||||
// invalidate old item
|
||||
c.invalid(key, item.value)
|
||||
old := item.value
|
||||
|
||||
// update item + expiry
|
||||
item.value = swp
|
||||
item.expiry = time.Now().Add(c.ttl)
|
||||
|
||||
return old
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Has(key K) bool {
|
||||
c.Lock()
|
||||
ok := c.HasUnsafe(key)
|
||||
c.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasUnsafe is the mutex-unprotected logic for Cache.Has().
|
||||
func (c *TTLCache[K, V]) HasUnsafe(key K) bool {
|
||||
_, ok := c.cache[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Invalidate(key K) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
return c.InvalidateUnsafe(key)
|
||||
}
|
||||
|
||||
// InvalidateUnsafe is mutex-unprotected logic for Cache.Invalidate().
|
||||
func (c *TTLCache[K, V]) InvalidateUnsafe(key K) bool {
|
||||
// Check if we have item with key
|
||||
item, ok := c.cache[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Call hook, remove from cache
|
||||
c.invalid(key, item.value)
|
||||
delete(c.cache, key)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.ClearUnsafe()
|
||||
}
|
||||
|
||||
// ClearUnsafe is mutex-unprotected logic for Cache.Clean().
|
||||
func (c *TTLCache[K, V]) ClearUnsafe() {
|
||||
for key, item := range c.cache {
|
||||
c.invalid(key, item.value)
|
||||
delete(c.cache, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TTLCache[K, V]) Size() int {
|
||||
c.Lock()
|
||||
sz := c.SizeUnsafe()
|
||||
c.Unlock()
|
||||
return sz
|
||||
}
|
||||
|
||||
// SizeUnsafe is mutex unprotected logic for Cache.Size().
|
||||
func (c *TTLCache[K, V]) SizeUnsafe() int {
|
||||
return len(c.cache)
|
||||
}
|
||||
|
||||
// entry represents an item in the cache, with
|
||||
// it's currently calculated expiry time.
|
||||
type entry[Value any] struct {
|
||||
value Value
|
||||
expiry time.Time
|
||||
}
|
5
vendor/modules.txt
vendored
5
vendor/modules.txt
vendored
|
@ -4,9 +4,12 @@ codeberg.org/gruf/go-bitutil
|
|||
# codeberg.org/gruf/go-bytes v1.0.2
|
||||
## explicit; go 1.14
|
||||
codeberg.org/gruf/go-bytes
|
||||
# codeberg.org/gruf/go-byteutil v1.0.0
|
||||
# codeberg.org/gruf/go-byteutil v1.0.1
|
||||
## explicit; go 1.16
|
||||
codeberg.org/gruf/go-byteutil
|
||||
# codeberg.org/gruf/go-cache/v2 v2.0.1
|
||||
## explicit; go 1.18
|
||||
codeberg.org/gruf/go-cache/v2
|
||||
# codeberg.org/gruf/go-debug v1.1.2
|
||||
## explicit; go 1.16
|
||||
codeberg.org/gruf/go-debug
|
||||
|
|
Loading…
Reference in a new issue