diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 42cbf318b..68b039d0c 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -87,9 +87,9 @@ var Start action.GTSAction = func(ctx context.Context) error { // defer function for safe shutdown // depending on what services were // managed to be started. - - state = new(state.State) - route *router.Router + state = new(state.State) + route *router.Router + process *processing.Processor ) defer func() { @@ -125,6 +125,23 @@ var Start action.GTSAction = func(ctx context.Context) error { } } + if process != nil { + const timeout = time.Minute + + // Use a new timeout context to ensure + // persisting queued tasks does not fail! + // The main ctx is very likely canceled. + ctx := context.WithoutCancel(ctx) + ctx, cncl := context.WithTimeout(ctx, timeout) + defer cncl() + + // Now that all the "moving" components have been stopped, + // persist any remaining queued worker tasks to the database. + if err := process.Admin().PersistWorkerQueues(ctx); err != nil { + log.Errorf(ctx, "error persisting worker queues: %v", err) + } + } + if state.DB != nil { // Lastly, if database service was started, // ensure it gets closed now all else stopped. @@ -270,7 +287,7 @@ var Start action.GTSAction = func(ctx context.Context) error { // Create the processor using all the // other services we've created so far. - processor := processing.NewProcessor( + process = processing.NewProcessor( cleaner, typeConverter, federator, @@ -286,14 +303,14 @@ var Start action.GTSAction = func(ctx context.Context) error { state.Workers.Client.Init(messages.ClientMsgIndices()) state.Workers.Federator.Init(messages.FederatorMsgIndices()) state.Workers.Delivery.Init(client) - state.Workers.Client.Process = processor.Workers().ProcessFromClientAPI - state.Workers.Federator.Process = processor.Workers().ProcessFromFediAPI + state.Workers.Client.Process = process.Workers().ProcessFromClientAPI + state.Workers.Federator.Process = process.Workers().ProcessFromFediAPI // Now start workers! state.Workers.Start() // Schedule notif tasks for all existing poll expiries. - if err := processor.Polls().ScheduleAll(ctx); err != nil { + if err := process.Polls().ScheduleAll(ctx); err != nil { return fmt.Errorf("error scheduling poll expiries: %w", err) } @@ -303,7 +320,7 @@ var Start action.GTSAction = func(ctx context.Context) error { } // Run advanced migrations. - if err := processor.AdvancedMigrations().Migrate(ctx); err != nil { + if err := process.AdvancedMigrations().Migrate(ctx); err != nil { return err } @@ -370,7 +387,7 @@ var Start action.GTSAction = func(ctx context.Context) error { // attach global no route / 404 handler to the router route.AttachNoRouteHandler(func(c *gin.Context) { - apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGetV1) + apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), process.InstanceGetV1) }) // build router modules @@ -393,15 +410,15 @@ var Start action.GTSAction = func(ctx context.Context) error { } var ( - authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths - clientModule = api.NewClient(state, processor) // api client endpoints - metricsModule = api.NewMetrics() // Metrics endpoints - healthModule = api.NewHealth(dbService.Ready) // Health check endpoints - fileserverModule = api.NewFileserver(processor) // fileserver endpoints - wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints - nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint - activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints - webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc + authModule = api.NewAuth(dbService, process, idp, routerSession, sessionName) // auth/oauth paths + clientModule = api.NewClient(state, process) // api client endpoints + metricsModule = api.NewMetrics() // Metrics endpoints + healthModule = api.NewHealth(dbService.Ready) // Health check endpoints + fileserverModule = api.NewFileserver(process) // fileserver endpoints + wellKnownModule = api.NewWellKnown(process) // .well-known endpoints + nodeInfoModule = api.NewNodeInfo(process) // nodeinfo endpoint + activityPubModule = api.NewActivityPub(dbService, process) // ActivityPub endpoints + webModule = web.New(dbService, process) // web pages + user profiles + settings panels etc ) // create required middleware @@ -416,10 +433,11 @@ var Start action.GTSAction = func(ctx context.Context) error { // throttling cpuMultiplier := config.GetAdvancedThrottlingMultiplier() retryAfter := config.GetAdvancedThrottlingRetryAfter() - clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api - s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP) - fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis - pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately + clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api + s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) + // server-to-server (AP) + fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis + pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately gzip := middleware.Gzip() // applied to all except fileserver @@ -442,6 +460,11 @@ var Start action.GTSAction = func(ctx context.Context) error { return fmt.Errorf("error starting router: %w", err) } + // Fill worker queues from persisted task data in database. + if err := process.Admin().FillWorkerQueues(ctx); err != nil { + return fmt.Errorf("error filling worker queues: %w", err) + } + // catch shutdown signals from the operating system sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 070d4eb91..d5071d141 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -84,6 +84,7 @@ type DBService struct { db.Timeline db.User db.Tombstone + db.WorkerTask db *bun.DB } @@ -302,6 +303,9 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + WorkerTask: &workerTaskDB{ + db: db, + }, db: db, } diff --git a/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go new file mode 100644 index 000000000..3b0ebcfd8 --- /dev/null +++ b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go @@ -0,0 +1,51 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // WorkerTask table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.WorkerTask{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/workertask.go b/internal/db/bundb/workertask.go new file mode 100644 index 000000000..eec51530d --- /dev/null +++ b/internal/db/bundb/workertask.go @@ -0,0 +1,58 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type workerTaskDB struct{ db *bun.DB } + +func (w *workerTaskDB) GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) { + var tasks []*gtsmodel.WorkerTask + if err := w.db.NewSelect(). + Model(&tasks). + OrderExpr("? ASC", bun.Ident("created_at")). + Scan(ctx); err != nil { + return nil, err + } + return tasks, nil +} + +func (w *workerTaskDB) PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error { + var errs []error + for _, task := range tasks { + _, err := w.db.NewInsert().Model(task).Exec(ctx) + if err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func (w *workerTaskDB) DeleteWorkerTaskByID(ctx context.Context, id uint) error { + _, err := w.db.NewDelete(). + Table("worker_tasks"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err +} diff --git a/internal/db/db.go b/internal/db/db.go index 4b2152732..cd621871a 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -56,4 +56,5 @@ type DB interface { Timeline User Tombstone + WorkerTask } diff --git a/internal/db/workertask.go b/internal/db/workertask.go new file mode 100644 index 000000000..0276f231a --- /dev/null +++ b/internal/db/workertask.go @@ -0,0 +1,35 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type WorkerTask interface { + // GetWorkerTasks fetches all persisted worker tasks from the database. + GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) + + // PutWorkerTasks persists the given worker tasks to the database. + PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error + + // DeleteWorkerTask deletes worker task with given ID from database. + DeleteWorkerTaskByID(ctx context.Context, id uint) error +} diff --git a/internal/gtsmodel/workertask.go b/internal/gtsmodel/workertask.go index cc8433199..758fc4cd7 100644 --- a/internal/gtsmodel/workertask.go +++ b/internal/gtsmodel/workertask.go @@ -34,8 +34,8 @@ const ( // queued tasks from being lost. It is simply a // means to store a blob of serialized task data. type WorkerTask struct { - ID uint `bun:""` - WorkerType uint8 `bun:""` - TaskData []byte `bun:""` - CreatedAt time.Time `bun:""` + ID uint `bun:",pk,autoincrement"` + WorkerType WorkerType `bun:",notnull"` + TaskData []byte `bun:",nullzero,notnull"` + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` } diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index b78dbc2d9..30ef0b04d 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -197,7 +197,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) { // If the fast-fail flag was set, just // attempt a single iteration instead of // following the below retry-backoff loop. - rsp, _, err = c.DoOnce(&req) + rsp, _, err = c.DoOnce(req) if err != nil { return nil, fmt.Errorf("%w (fast fail)", err) } @@ -208,7 +208,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) { var retry bool // Perform the http request. - rsp, retry, err = c.DoOnce(&req) + rsp, retry, err = c.DoOnce(req) if err == nil { return rsp, nil } diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go index e5a7f44d3..dfe51b160 100644 --- a/internal/httpclient/request.go +++ b/internal/httpclient/request.go @@ -47,8 +47,8 @@ type Request struct { // WrapRequest wraps an existing http.Request within // our own httpclient.Request with retry / backoff tracking. -func WrapRequest(r *http.Request) Request { - var rr Request +func WrapRequest(r *http.Request) *Request { + rr := new(Request) rr.Request = r entry := log.WithContext(r.Context()) entry = entry.WithField("method", r.Method) diff --git a/internal/messages/messages.go b/internal/messages/messages.go index 7779633ba..d652c0c5c 100644 --- a/internal/messages/messages.go +++ b/internal/messages/messages.go @@ -352,7 +352,7 @@ func resolveAPObject(data map[string]interface{}) (interface{}, error) { // we then need to wrangle back into the original type. So we also store the type name // and use this to determine the appropriate Go structure type to unmarshal into to. func resolveGTSModel(typ string, data []byte) (interface{}, error) { - if typ == "" && data == nil { + if typ == "" { // No data given. return nil, nil } diff --git a/internal/processing/admin/workertask.go b/internal/processing/admin/workertask.go new file mode 100644 index 000000000..6d7cc7b7a --- /dev/null +++ b/internal/processing/admin/workertask.go @@ -0,0 +1,426 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package admin + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/transport" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" +) + +// NOTE: +// Having these functions in the processor, which is +// usually the intermediary that performs *processing* +// between the HTTP route handlers and the underlying +// database / storage layers is a little odd, so this +// may be subject to change! +// +// For now at least, this is a useful place that has +// access to the underlying database, workers and +// causes no dependency cycles with this use case! + +// FillWorkerQueues recovers all serialized worker tasks from the database +// (if any!), and pushes them to each of their relevant worker queues. +func (p *Processor) FillWorkerQueues(ctx context.Context) error { + log.Info(ctx, "rehydrate!") + + // Get all persisted worker tasks from db. + // + // (database returns these as ASCENDING, i.e. + // returned in the order they were inserted). + tasks, err := p.state.DB.GetWorkerTasks(ctx) + if err != nil { + return gtserror.Newf("error fetching worker tasks from db: %w", err) + } + + var ( + // Counts of each task type + // successfully recovered. + delivery int + federator int + client int + + // Failed recoveries. + errors int + ) + +loop: + + // Handle each persisted task, removing + // all those we can't handle. Leaving us + // with a slice of tasks we can safely + // delete from being persisted in the DB. + for i := 0; i < len(tasks); { + var err error + + // Task at index. + task := tasks[i] + + // Appropriate task count + // pointer to increment. + var counter *int + + // Attempt to recovery persisted + // task depending on worker type. + switch task.WorkerType { + case gtsmodel.DeliveryWorker: + err = p.pushDelivery(ctx, task) + counter = &delivery + case gtsmodel.FederatorWorker: + err = p.pushFederator(ctx, task) + counter = &federator + case gtsmodel.ClientWorker: + err = p.pushClient(ctx, task) + counter = &client + default: + err = fmt.Errorf("invalid worker type %d", task.WorkerType) + } + + if err != nil { + log.Errorf(ctx, "error pushing task %d: %v", task.ID, err) + + // Drop error'd task from slice. + tasks = slices.Delete(tasks, i, i+1) + + // Incr errors. + errors++ + continue loop + } + + // Increment slice + // index & counter. + (*counter)++ + i++ + } + + // Tasks that worker successfully pushed + // to their appropriate workers, we can + // safely now remove from the database. + for _, task := range tasks { + if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil { + log.Errorf(ctx, "error deleting task from db: %v", err) + } + } + + // Log recovered tasks. + log.WithContext(ctx). + WithField("delivery", delivery). + WithField("federator", federator). + WithField("client", client). + WithField("errors", errors). + Info("recovered queued tasks") + + return nil +} + +// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not +// dereference tasks which are just function ptrs), serializes and persists them to the database. +func (p *Processor) PersistWorkerQueues(ctx context.Context) error { + log.Info(ctx, "dehydrate!") + + var ( + // Counts of each task type + // successfully persisted. + delivery int + federator int + client int + + // Failed persists. + errors int + + // Serialized tasks to persist. + tasks []*gtsmodel.WorkerTask + ) + + for { + // Pop all queued deliveries. + task, err := p.popDelivery() + if err != nil { + log.Errorf(ctx, "error popping delivery: %v", err) + errors++ // incr error count. + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + delivery++ // incr count + } + + for { + // Pop queued federator msgs. + task, err := p.popFederator() + if err != nil { + log.Errorf(ctx, "error popping federator message: %v", err) + errors++ // incr count + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + federator++ // incr count + } + + for { + // Pop queued client msgs. + task, err := p.popClient() + if err != nil { + log.Errorf(ctx, "error popping client message: %v", err) + continue + } + + if task == nil { + // No more queue + // tasks to pop! + break + } + + // Append serialized task. + tasks = append(tasks, task) + client++ // incr count + } + + // Persist all serialized queued worker tasks to database. + if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil { + return gtserror.Newf("error putting tasks in db: %w", err) + } + + // Log recovered tasks. + log.WithContext(ctx). + WithField("delivery", delivery). + WithField("federator", federator). + WithField("client", client). + WithField("errors", errors). + Info("persisted queued tasks") + + return nil +} + +// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue. +func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error { + dlv := new(delivery.Delivery) + + // Deserialize the raw worker task data into delivery. + if err := dlv.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing delivery: %w", err) + } + + var tsport transport.Transport + + if uri := dlv.ActorID; uri != "" { + // Fetch the actor account by provided URI from db. + account, err := p.state.DB.GetAccountByURI(ctx, uri) + if err != nil { + return gtserror.Newf("error getting actor account %s from db: %w", uri, err) + } + + // Fetch a transport for request signing for actor's account username. + tsport, err = p.transport.NewTransportForUsername(ctx, account.Username) + if err != nil { + return gtserror.Newf("error getting transport for actor %s: %w", uri, err) + } + } else { + var err error + + // No actor was given, will be signed by instance account. + tsport, err = p.transport.NewTransportForUsername(ctx, "") + if err != nil { + return gtserror.Newf("error getting instance account transport: %w", err) + } + } + + // Using transport, add actor signature to delivery. + if err := tsport.SignDelivery(dlv); err != nil { + return gtserror.Newf("error signing delivery: %w", err) + } + + // Push deserialized task to delivery queue. + p.state.Workers.Delivery.Queue.Push(dlv) + + return nil +} + +// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data. +func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) { + + // Pop waiting delivery from the delivery worker. + delivery, ok := p.state.Workers.Delivery.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize the delivery task data. + data, err := delivery.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing delivery: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.DeliveryWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} + +// pushClient parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error { + var msg messages.FromFediAPI + + // Deserialize the raw worker task data into message. + if err := msg.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing federator message: %w", err) + } + + if rcv := msg.Receiving; rcv != nil { + // Only a placeholder receiving account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, rcv.ID) + if err != nil { + return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err) + } + + // Set the now populated + // receiving account model. + msg.Receiving = account + } + + if req := msg.Requesting; req != nil { + // Only a placeholder requesting account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, req.ID) + if err != nil { + return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err) + } + + // Set the now populated + // requesting account model. + msg.Requesting = account + } + + // Push populated task to the federator queue. + p.state.Workers.Federator.Queue.Push(&msg) + + return nil +} + +// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data. +func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) { + + // Pop waiting message from the federator worker. + msg, ok := p.state.Workers.Federator.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize message task data. + data, err := msg.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing federator message: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.FederatorWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} + +// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue. +func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error { + var msg messages.FromClientAPI + + // Deserialize the raw worker task data into message. + if err := msg.Deserialize(task.TaskData); err != nil { + return gtserror.Newf("error deserializing client message: %w", err) + } + + if org := msg.Origin; org != nil { + // Only a placeholder origin account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, org.ID) + if err != nil { + return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err) + } + + // Set the now populated + // origin account model. + msg.Origin = account + } + + if trg := msg.Target; trg != nil { + // Only a placeholder target account will be populated, + // fetch the actual model from database by persisted ID. + account, err := p.state.DB.GetAccountByID(ctx, trg.ID) + if err != nil { + return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err) + } + + // Set the now populated + // target account model. + msg.Target = account + } + + // Push populated task to the federator queue. + p.state.Workers.Client.Queue.Push(&msg) + + return nil +} + +// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data. +func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) { + + // Pop waiting message from the client worker. + msg, ok := p.state.Workers.Client.Queue.Pop() + if !ok { + return nil, nil + } + + // Serialize message task data. + data, err := msg.Serialize() + if err != nil { + return nil, gtserror.Newf("error serializing client message: %w", err) + } + + return >smodel.WorkerTask{ + // ID is autoincrement + WorkerType: gtsmodel.ClientWorker, + TaskData: data, + CreatedAt: time.Now(), + }, nil +} diff --git a/internal/processing/admin/workertask_test.go b/internal/processing/admin/workertask_test.go new file mode 100644 index 000000000..bf326bafd --- /dev/null +++ b/internal/processing/admin/workertask_test.go @@ -0,0 +1,421 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package admin_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +var ( + // TODO: move these test values into + // the testrig test models area. They'll + // need to be as both WorkerTask and as + // the raw types themselves. + + testDeliveries = []*delivery.Delivery{ + { + ObjectID: "https://google.com/users/bigboy/follow/1", + TargetID: "https://askjeeves.com/users/smallboy", + Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}), + }, + { + Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}), + }, + } + + testFederatorMsgs = []*messages.FromFediAPI{ + { + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + TargetURI: "https://gotosocial.org", + Requesting: >smodel.Account{ID: "654321"}, + Receiving: >smodel.Account{ID: "123456"}, + }, + { + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + TargetURI: "https://uk-queen-is-dead.org", + Requesting: >smodel.Account{ID: "123456"}, + Receiving: >smodel.Account{ID: "654321"}, + }, + } + + testClientMsgs = []*messages.FromClientAPI{ + { + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + TargetURI: "https://gotosocial.org", + Origin: >smodel.Account{ID: "654321"}, + Target: >smodel.Account{ID: "123456"}, + }, + { + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + TargetURI: "https://uk-queen-is-dead.org", + Origin: >smodel.Account{ID: "123456"}, + Target: >smodel.Account{ID: "654321"}, + }, + } +) + +type WorkerTaskTestSuite struct { + AdminStandardTestSuite +} + +func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + var tasks []*gtsmodel.WorkerTask + + for _, dlv := range testDeliveries { + // Serialize all test deliveries. + data, err := dlv.Serialize() + if err != nil { + panic(err) + } + + // Append each serialized delivery to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.DeliveryWorker, + TaskData: data, + }) + } + + for _, msg := range testFederatorMsgs { + // Serialize all test messages. + data, err := msg.Serialize() + if err != nil { + panic(err) + } + + if msg.Receiving != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Receiving) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Receiving.ID, + }) + } + + if msg.Requesting != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Requesting) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Requesting.ID, + }) + } + + // Append each serialized message to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.FederatorWorker, + TaskData: data, + }) + } + + for _, msg := range testClientMsgs { + // Serialize all test messages. + data, err := msg.Serialize() + if err != nil { + panic(err) + } + + if msg.Origin != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Origin) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Origin.ID, + }) + } + + if msg.Target != nil { + // Quick hack to bypass database errors for non-existing + // accounts, instead we just insert this into cache ;). + suite.state.Caches.DB.Account.Put(msg.Target) + suite.state.Caches.DB.AccountSettings.Put(>smodel.AccountSettings{ + AccountID: msg.Target.ID, + }) + } + + // Append each serialized message to tasks. + tasks = append(tasks, >smodel.WorkerTask{ + WorkerType: gtsmodel.ClientWorker, + TaskData: data, + }) + } + + // Persist all test worker tasks to the database. + err := suite.state.DB.PutWorkerTasks(ctx, tasks) + suite.NoError(err) + + // Fill the worker queues from persisted task data. + err = suite.adminProcessor.FillWorkerQueues(ctx) + suite.NoError(err) + + var ( + // Recovered + // task counts. + ndelivery int + nfederator int + nclient int + ) + + // Fetch current gotosocial instance account, for later checks. + instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "") + suite.NoError(err) + + for { + // Pop all queued delivery tasks from worker queue. + dlv, ok := suite.state.Workers.Delivery.Queue.Pop() + if !ok { + break + } + + // Incr count. + ndelivery++ + + // Check that we have this message in slice. + err = containsSerializable(testDeliveries, dlv) + suite.NoError(err) + + // Check that delivery request context has instance account pubkey. + pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context()) + suite.Equal(instanceAcc.PublicKeyURI, pubKeyID) + signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context()) + suite.NotNil(signfn) + } + + for { + // Pop all queued federator messages from worker queue. + msg, ok := suite.state.Workers.Federator.Queue.Pop() + if !ok { + break + } + + // Incr count. + nfederator++ + + // Check that we have this message in slice. + err = containsSerializable(testFederatorMsgs, msg) + suite.NoError(err) + } + + for { + // Pop all queued client messages from worker queue. + msg, ok := suite.state.Workers.Client.Queue.Pop() + if !ok { + break + } + + // Incr count. + nclient++ + + // Check that we have this message in slice. + err = containsSerializable(testClientMsgs, msg) + suite.NoError(err) + } + + // Ensure recovered task counts as expected. + suite.Equal(len(testDeliveries), ndelivery) + suite.Equal(len(testFederatorMsgs), nfederator) + suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Push all test worker tasks to their respective queues. + suite.state.Workers.Delivery.Queue.Push(testDeliveries...) + suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...) + suite.state.Workers.Client.Queue.Push(testClientMsgs...) + + // Persist the worker queued tasks to database. + err := suite.adminProcessor.PersistWorkerQueues(ctx) + suite.NoError(err) + + // Fetch all the persisted tasks from database. + tasks, err := suite.state.DB.GetWorkerTasks(ctx) + suite.NoError(err) + + var ( + // Persisted + // task counts. + ndelivery int + nfederator int + nclient int + ) + + // Check persisted task data. + for _, task := range tasks { + switch task.WorkerType { + case gtsmodel.DeliveryWorker: + var dlv delivery.Delivery + + // Incr count. + ndelivery++ + + // Deserialize the persisted task data. + err := dlv.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this delivery in slice. + err = containsSerializable(testDeliveries, &dlv) + suite.NoError(err) + + case gtsmodel.FederatorWorker: + var msg messages.FromFediAPI + + // Incr count. + nfederator++ + + // Deserialize the persisted task data. + err := msg.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this message in slice. + err = containsSerializable(testFederatorMsgs, &msg) + suite.NoError(err) + + case gtsmodel.ClientWorker: + var msg messages.FromClientAPI + + // Incr count. + nclient++ + + // Deserialize the persisted task data. + err := msg.Deserialize(task.TaskData) + suite.NoError(err) + + // Check that we have this message in slice. + err = containsSerializable(testClientMsgs, &msg) + suite.NoError(err) + + default: + suite.T().Errorf("unexpected worker type: %d", task.WorkerType) + } + } + + // Ensure persisted task counts as expected. + suite.Equal(len(testDeliveries), ndelivery) + suite.Equal(len(testFederatorMsgs), nfederator) + suite.Equal(len(testClientMsgs), nclient) +} + +func (suite *WorkerTaskTestSuite) SetupTest() { + suite.AdminStandardTestSuite.SetupTest() + // we don't want workers running + testrig.StopWorkers(&suite.state) +} + +func TestWorkerTaskTestSuite(t *testing.T) { + suite.Run(t, new(WorkerTaskTestSuite)) +} + +// containsSerializeable returns whether slice of serializables contains given serializable entry. +func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error { + // Serialize wanted value. + bh, err := have.Serialize() + if err != nil { + panic(err) + } + + var strings []string + + for _, t := range expect { + // Serialize expected value. + be, err := t.Serialize() + if err != nil { + panic(err) + } + + // Alloc as string. + se := string(be) + + if se == string(bh) { + // We have this entry! + return nil + } + + // Add to serialized strings. + strings = append(strings, se) + } + + return fmt.Errorf("could not find %s in %s", string(bh), strings) +} + +// urlStr simply returns u.String() or "" if nil. +func urlStr(u *url.URL) string { + if u == nil { + return "" + } + return u.String() +} + +// accountID simply returns account.ID or "" if nil. +func accountID(account *gtsmodel.Account) string { + if account == nil { + return "" + } + return account.ID +} + +// toRequest creates httpclient.Request from HTTP method, URL and body data. +func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request { + var rbody io.Reader + if body != nil { + rbody = bytes.NewReader(body) + } + req, err := http.NewRequest(method, url, rbody) + if err != nil { + panic(err) + } + for key, values := range hdr { + for _, value := range values { + req.Header.Add(key, value) + } + } + return httpclient.WrapRequest(req) +} + +// toJSON marshals input type as JSON data. +func toJSON(a any) []byte { + b, err := json.Marshal(a) + if err != nil { + panic(err) + } + return b +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 30435b86f..36ad6f015 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -21,6 +21,7 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/url" @@ -169,6 +170,38 @@ func (t *transport) prepare( }, nil } +func (t *transport) SignDelivery(dlv *delivery.Delivery) error { + if dlv.Request.GetBody == nil { + return gtserror.New("delivery request body not rewindable") + } + + // Get a new copy of the request body. + body, err := dlv.Request.GetBody() + if err != nil { + return gtserror.Newf("error getting request body: %w", err) + } + + // Read body data into memory. + data, err := io.ReadAll(body) + if err != nil { + return gtserror.Newf("error reading request body: %w", err) + } + + // Get signing function for POST data. + // (note that delivery is ALWAYS POST). + sign := t.signPOST(data) + + // Extract delivery context. + ctx := dlv.Request.Context() + + // Update delivery request context with signing details. + ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) + ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign) + dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + + return nil +} + // getObjectID extracts an object ID from 'serialized' ActivityPub object map. func getObjectID(obj map[string]interface{}) string { switch t := obj["object"].(type) { diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go index 1e3ebb054..e11eea83c 100644 --- a/internal/transport/delivery/delivery.go +++ b/internal/transport/delivery/delivery.go @@ -33,10 +33,6 @@ import ( // be indexed (and so, dropped from queue) // by any of these possible ID IRIs. type Delivery struct { - // PubKeyID is the signing public key - // ID of the actor performing request. - PubKeyID string - // ActorID contains the ActivityPub // actor ID IRI (if any) of the activity // being sent out by this request. @@ -55,7 +51,7 @@ type Delivery struct { // Request is the prepared (+ wrapped) // httpclient.Client{} request that // constitutes this ActivtyPub delivery. - Request httpclient.Request + Request *httpclient.Request // internal fields. next time.Time @@ -66,7 +62,6 @@ type Delivery struct { // a json serialize / deserialize // able shape that minimizes data. type delivery struct { - PubKeyID string `json:"pub_key_id,omitempty"` ActorID string `json:"actor_id,omitempty"` ObjectID string `json:"object_id,omitempty"` TargetID string `json:"target_id,omitempty"` @@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) { // Marshal as internal JSON type. return json.Marshal(delivery{ - PubKeyID: dlv.PubKeyID, ActorID: dlv.ActorID, ObjectID: dlv.ObjectID, TargetID: dlv.TargetID, @@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error { } // Copy over simplest fields. - dlv.PubKeyID = idlv.PubKeyID dlv.ActorID = idlv.ActorID dlv.ObjectID = idlv.ObjectID dlv.TargetID = idlv.TargetID @@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error { return err } + // Copy over any stored header values. + for key, values := range idlv.Header { + for _, value := range values { + r.Header.Add(key, value) + } + } + // Wrap request in httpclient type. dlv.Request = httpclient.WrapRequest(r) diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go index e9eaf8fd1..81f32d5f8 100644 --- a/internal/transport/delivery/delivery_test.go +++ b/internal/transport/delivery/delivery_test.go @@ -35,32 +35,30 @@ var deliveryCases = []struct { }{ { msg: delivery.Delivery{ - PubKeyID: "https://google.com/users/bigboy#pubkey", ActorID: "https://google.com/users/bigboy", ObjectID: "https://google.com/users/bigboy/follow/1", TargetID: "https://askjeeves.com/users/smallboy", - Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), + Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}), }, data: toJSON(map[string]any{ - "pub_key_id": "https://google.com/users/bigboy#pubkey", - "actor_id": "https://google.com/users/bigboy", - "object_id": "https://google.com/users/bigboy/follow/1", - "target_id": "https://askjeeves.com/users/smallboy", - "method": "POST", - "url": "https://askjeeves.com/users/smallboy/inbox", - "body": []byte("data!"), - // "header": map[string][]string{}, + "actor_id": "https://google.com/users/bigboy", + "object_id": "https://google.com/users/bigboy/follow/1", + "target_id": "https://askjeeves.com/users/smallboy", + "method": "POST", + "url": "https://askjeeves.com/users/smallboy/inbox", + "body": []byte("data!"), + "header": map[string][]string{"Hello": {"world1", "world2"}}, }), }, { msg: delivery.Delivery{ - Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), + Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil), }, data: toJSON(map[string]any{ "method": "GET", "url": "https://google.com", "body": []byte("uwu im just a wittle seawch engwin"), - // "header": map[string][]string{}, + // "header": map[string][]string{}, }), }, } @@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) { } // Check that delivery fields are as expected. - assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID) assert.Equal(t, test.msg.ActorID, msg.ActorID) assert.Equal(t, test.msg.ObjectID, msg.ObjectID) assert.Equal(t, test.msg.TargetID, msg.TargetID) assert.Equal(t, test.msg.Request.Method, msg.Request.Method) assert.Equal(t, test.msg.Request.URL, msg.Request.URL) assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body)) + assert.Equal(t, test.msg.Request.Header, msg.Request.Header) } } // toRequest creates httpclient.Request from HTTP method, URL and body data. -func toRequest(method string, url string, body []byte) httpclient.Request { +func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request { var rbody io.Reader if body != nil { rbody = bytes.NewReader(body) @@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request { if err != nil { panic(err) } + for key, values := range hdr { + for _, value := range values { + req.Header.Add(key, value) + } + } return httpclient.WrapRequest(req) } diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go index ef31e94a6..d6d253769 100644 --- a/internal/transport/delivery/worker.go +++ b/internal/transport/delivery/worker.go @@ -19,6 +19,7 @@ package delivery import ( "context" + "errors" "slices" "time" @@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool { loop: for { + // Before trying to get + // next delivery, check + // context still valid. + if ctx.Err() != nil { + return true + } + // Get next delivery. dlv, ok := w.next(ctx) if !ok { @@ -195,16 +203,30 @@ loop: // Attempt delivery of AP request. rsp, retry, err := w.Client.DoOnce( - &dlv.Request, + dlv.Request, ) - if err == nil { + switch { + case err == nil: // Ensure body closed. _ = rsp.Body.Close() continue loop - } - if !retry { + case errors.Is(err, context.Canceled) && + ctx.Err() != nil: + // In the case of our own context + // being cancelled, push delivery + // back onto queue for persisting. + // + // Note we specifically check against + // context.Canceled here as it will + // be faster than the mutex lock of + // ctx.Err(), so gives an initial + // faster check in the if-clause. + w.Queue.Push(dlv) + continue loop + + case !retry: // Drop deliveries when no // retry requested, or they // reached max (either). @@ -222,42 +244,36 @@ loop: // next gets the next available delivery, blocking until available if necessary. func (w *Worker) next(ctx context.Context) (*Delivery, bool) { -loop: - for { - // Try pop next queued. - dlv, ok := w.Queue.Pop() + // Try a fast-pop of queued + // delivery before anything. + dlv, ok := w.Queue.Pop() - if !ok { - // Check the backlog. - if len(w.backlog) > 0 { + if !ok { + // Check the backlog. + if len(w.backlog) > 0 { - // Sort by 'next' time. - sortDeliveries(w.backlog) + // Sort by 'next' time. + sortDeliveries(w.backlog) - // Pop next delivery. - dlv := w.popBacklog() + // Pop next delivery. + dlv := w.popBacklog() - return dlv, true - } - - select { - // Backlog is empty, we MUST - // block until next enqueued. - case <-w.Queue.Wait(): - continue loop - - // Worker was stopped. - case <-ctx.Done(): - return nil, false - } + return dlv, true } - // Replace request context for worker state canceling. - ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) - dlv.Request.Request = dlv.Request.Request.WithContext(ctx) - - return dlv, true + // Block on next delivery push + // OR worker context canceled. + dlv, ok = w.Queue.PopCtx(ctx) + if !ok { + return nil, false + } } + + // Replace request context for worker state canceling. + ctx = gtscontext.WithValues(ctx, dlv.Request.Context()) + dlv.Request.Request = dlv.Request.Request.WithContext(ctx) + + return dlv, true } // popBacklog pops next available from the backlog. diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 2971ca603..7f7e985fc 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -30,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/httpclient" + "github.com/superseriousbusiness/gotosocial/internal/transport/delivery" "github.com/superseriousbusiness/httpsig" ) @@ -50,6 +51,10 @@ type Transport interface { // transport client, retrying on certain preset errors. POST(*http.Request, []byte) (*http.Response, error) + // SignDelivery adds HTTP request signing client "middleware" + // to the request context within given delivery.Delivery{}. + SignDelivery(*delivery.Delivery) error + // Deliver sends an ActivityStreams object. Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error diff --git a/internal/workers/worker_msg.go b/internal/workers/worker_msg.go index 92180651a..c7dc568d7 100644 --- a/internal/workers/worker_msg.go +++ b/internal/workers/worker_msg.go @@ -19,6 +19,7 @@ package workers import ( "context" + "errors" "codeberg.org/gruf/go-runners" "codeberg.org/gruf/go-structr" @@ -147,9 +148,25 @@ func (w *MsgWorker[T]) process(ctx context.Context) { return } - // Attempt to process popped message type. - if err := w.Process(ctx, msg); err != nil { + // Attempt to process message. + err := w.Process(ctx, msg) + if err != nil { log.Errorf(ctx, "%p: error processing: %v", w, err) + + if errors.Is(err, context.Canceled) && + ctx.Err() != nil { + // In the case of our own context + // being cancelled, push message + // back onto queue for persisting. + // + // Note we specifically check against + // context.Canceled here as it will + // be faster than the mutex lock of + // ctx.Err(), so gives an initial + // faster check in the if-clause. + w.Queue.Push(msg) + break + } } } } diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 4d2b146b6..377a9d899 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -55,7 +55,8 @@ type Workers struct { // StartScheduler starts the job scheduler. func (w *Workers) StartScheduler() { - _ = w.Scheduler.Start() // false = already running + _ = w.Scheduler.Start() + // false = already running log.Info(nil, "started scheduler") } @@ -82,9 +83,12 @@ func (w *Workers) Start() { log.Infof(nil, "started %d dereference workers", n) } -// Stop will stop all of the contained worker pools (and global scheduler). +// Stop will stop all of the contained +// worker pools (and global scheduler). func (w *Workers) Stop() { - _ = w.Scheduler.Stop() // false = not running + _ = w.Scheduler.Stop() + // false = not running + log.Info(nil, "stopped scheduler") w.Delivery.Stop() log.Info(nil, "stopped delivery workers") diff --git a/testrig/db.go b/testrig/db.go index 67a7e2439..e6b40c846 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -29,6 +29,8 @@ import ( var testModels = []interface{}{ >smodel.Account{}, + >smodel.AccountNote{}, + >smodel.AccountSettings{}, >smodel.AccountToEmoji{}, >smodel.Application{}, >smodel.Block{}, @@ -67,8 +69,7 @@ var testModels = []interface{}{ >smodel.Tombstone{}, >smodel.Report{}, >smodel.Rule{}, - >smodel.AccountNote{}, - >smodel.AccountSettings{}, + >smodel.WorkerTask{}, } // NewTestDB returns a new initialized, empty database for testing.