[bugfix] Rework MultiError to wrap + unwrap errors properly (#2057)

* rework multierror a bit

* test multierror
This commit is contained in:
tobi 2023-08-02 17:21:46 +02:00 committed by GitHub
parent 2cee8f2dd8
commit e8a20f587c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 263 additions and 154 deletions

View file

@ -75,14 +75,14 @@ func setupPrune(ctx context.Context) (*prune, error) {
}
func (p *prune) shutdown(ctx context.Context) error {
var errs gtserror.MultiError
errs := gtserror.NewMultiError(2)
if err := p.storage.Close(); err != nil {
errs.Appendf("error closing storage backend: %v", err)
errs.Appendf("error closing storage backend: %w", err)
}
if err := p.dbService.Stop(ctx); err != nil {
errs.Appendf("error stopping database: %v", err)
errs.Appendf("error stopping database: %w", err)
}
p.state.Workers.Stop()

View file

@ -22,7 +22,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -105,16 +104,16 @@ func (suite *InboxPostTestSuite) inboxPost(
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if err := errs.Combine(); err != nil {

View file

@ -90,16 +90,16 @@ func (suite *AccountUpdateTestSuite) updateAccount(
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if err := errs.Combine(); err != nil {

View file

@ -19,7 +19,6 @@ package accounts_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -63,16 +62,16 @@ func (suite *ListsTestSuite) getLists(targetAccountID string, expectedHTTPStatus
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if err := errs.Combine(); err != nil {

View file

@ -19,7 +19,6 @@ package accounts_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -99,16 +98,16 @@ func (suite *AccountSearchTestSuite) getSearch(
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if err := errs.Combine(); err != nil {

View file

@ -19,7 +19,6 @@ package admin_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -84,16 +83,16 @@ func (suite *ReportResolveTestSuite) resolveReport(
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}

View file

@ -19,7 +19,6 @@ package admin_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -101,16 +100,16 @@ func (suite *ReportsGetTestSuite) getReports(
return nil, "", err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, "", errs.Combine()
}

View file

@ -19,7 +19,6 @@ package lists_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -103,17 +102,17 @@ func (suite *ListAccountsTestSuite) getListAccounts(
return nil, "", err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, "", errs.Combine()
}

View file

@ -19,7 +19,6 @@ package reports_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -77,17 +76,17 @@ func (suite *ReportCreateTestSuite) createReport(expectedHTTPStatus int, expecte
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}

View file

@ -19,7 +19,6 @@ package reports_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -64,17 +63,17 @@ func (suite *ReportGetTestSuite) getReport(expectedHTTPStatus int, expectedBody
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}

View file

@ -22,7 +22,6 @@ import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -122,16 +121,16 @@ func (suite *SearchGetTestSuite) getSearch(
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d: %v", expectedHTTPStatus, resultCode, ctx.Errors.JSON()))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if err := errs.Combine(); err != nil {

View file

@ -20,7 +20,6 @@ package statuses_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -74,20 +73,20 @@ func (suite *StatusPinTestSuite) createPin(
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// check code + body
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if len(errs) > 0 {
return nil, errs.Combine()
if err := errs.Combine(); err != nil {
suite.FailNow("", "%v (body %s)", err, string(b))
}
resp := &apimodel.Status{}

View file

@ -19,7 +19,6 @@ package statuses_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -68,20 +67,20 @@ func (suite *StatusUnpinTestSuite) createUnpin(
return nil, err
}
errs := gtserror.MultiError{}
errs := gtserror.NewMultiError(2)
// check code + body
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
if len(errs) > 0 {
return nil, errs.Combine()
if err := errs.Combine(); err != nil {
suite.FailNow("", "%v (body %s)", err, string(b))
}
resp := &apimodel.Status{}

View file

@ -83,19 +83,23 @@ func (c *Cleaner) removeFiles(ctx context.Context, files ...string) (int, error)
return len(files), nil
}
var errs gtserror.MultiError
var (
errs gtserror.MultiError
errCount int
)
for _, path := range files {
// Remove each provided storage path.
log.Debugf(ctx, "removing file: %s", path)
err := c.state.Storage.Delete(ctx, path)
if err != nil && !errors.Is(err, storage.ErrNotFound) {
errs.Appendf("error removing %s: %v", path, err)
errs.Appendf("error removing %s: %w", path, err)
errCount++
}
}
// Calculate no. files removed.
diff := len(files) - len(errs)
diff := len(files) - errCount
// Wrap the combined error slice.
if err := errs.Combine(); err != nil {

View file

@ -20,7 +20,6 @@ package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
@ -255,7 +254,7 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
var (
err error
errs = make(gtserror.MultiError, 0, 3)
errs = gtserror.NewMultiError(3)
)
if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
@ -265,7 +264,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
account.AvatarMediaAttachmentID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating account avatar: %w", err))
errs.Appendf("error populating account avatar: %w", err)
}
}
@ -276,7 +275,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
account.HeaderMediaAttachmentID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating account header: %w", err))
errs.Appendf("error populating account header: %w", err)
}
}
@ -287,11 +286,15 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
account.EmojiIDs,
)
if err != nil {
errs.Append(fmt.Errorf("error populating account emojis: %w", err))
errs.Appendf("error populating account emojis: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {

View file

@ -173,7 +173,7 @@ func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery fun
func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
errs = gtserror.NewMultiError(2)
)
if instance.DomainBlockID != "" && instance.DomainBlock == nil {
@ -183,7 +183,7 @@ func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.In
instance.Domain,
)
if err != nil {
errs.Append(gtserror.Newf("error populating instance domain block: %w", err))
errs.Appendf("error populating instance domain block: %w", err)
}
}
@ -194,11 +194,15 @@ func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.In
instance.ContactAccountID,
)
if err != nil {
errs.Append(gtserror.Newf("error populating instance contact account: %w", err))
errs.Appendf("error populating instance contact account: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error {

View file

@ -117,7 +117,7 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
errs = gtserror.NewMultiError(2)
)
if list.Account == nil {
@ -127,7 +127,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
list.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list account: %w", err))
errs.Appendf("error populating list account: %w", err)
}
}
@ -139,11 +139,15 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
"", "", "", 0,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list entries: %w", err))
errs.Appendf("error populating list entries: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {

View file

@ -160,7 +160,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
errs = gtserror.NewMultiError(2)
)
if follow.Account == nil {
@ -170,7 +170,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
follow.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating follow account: %w", err))
errs.Appendf("error populating follow account: %w", err)
}
}
@ -181,11 +181,15 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
follow.TargetAccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating follow target account: %w", err))
errs.Appendf("error populating follow target account: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {

View file

@ -22,7 +22,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -129,7 +128,7 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
var (
err error
errs = make(gtserror.MultiError, 0, 9)
errs = gtserror.NewMultiError(9)
)
if status.Account == nil {
@ -139,7 +138,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status author: %w", err))
errs.Appendf("error populating status author: %w", err)
}
}
@ -150,7 +149,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.InReplyToID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status parent: %w", err))
errs.Appendf("error populating status parent: %w", err)
}
}
@ -162,7 +161,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.InReplyToID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status parent: %w", err))
errs.Appendf("error populating status parent: %w", err)
}
}
@ -173,7 +172,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.InReplyToAccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status parent author: %w", err))
errs.Appendf("error populating status parent author: %w", err)
}
}
}
@ -186,7 +185,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.BoostOfID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status boost: %w", err))
errs.Appendf("error populating status boost: %w", err)
}
}
@ -197,7 +196,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.BoostOfAccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status boost author: %w", err))
errs.Appendf("error populating status boost author: %w", err)
}
}
}
@ -209,7 +208,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.AttachmentIDs,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status attachments: %w", err))
errs.Appendf("error populating status attachments: %w", err)
}
}
@ -220,7 +219,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.TagIDs,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status tags: %w", err))
errs.Appendf("error populating status tags: %w", err)
}
}
@ -231,7 +230,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.MentionIDs,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status mentions: %w", err))
errs.Appendf("error populating status mentions: %w", err)
}
}
@ -242,11 +241,15 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
status.EmojiIDs,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status emojis: %w", err))
errs.Appendf("error populating status emojis: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {

View file

@ -149,7 +149,7 @@ func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID str
func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error {
var (
err error
errs = make(gtserror.MultiError, 0, 3)
errs = gtserror.NewMultiError(3)
)
if statusFave.Account == nil {
@ -159,7 +159,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
statusFave.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status fave author: %w", err))
errs.Appendf("error populating status fave author: %w", err)
}
}
@ -170,7 +170,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
statusFave.TargetAccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status fave target account: %w", err))
errs.Appendf("error populating status fave target account: %w", err)
}
}
@ -181,11 +181,15 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
statusFave.StatusID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating status fave status: %w", err))
errs.Appendf("error populating status fave status: %w", err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {

View file

@ -20,25 +20,48 @@ package gtserror
import (
"errors"
"fmt"
"strings"
)
// MultiError allows encapsulating multiple errors under a singular instance,
// which is useful when you only want to log on errors, not return early / bubble up.
type MultiError []string
func (e *MultiError) Append(err error) {
*e = append(*e, err.Error())
// MultiError allows encapsulating multiple
// errors under a singular instance, which
// is useful when you only want to log on
// errors, not return early / bubble up.
type MultiError struct {
e []error
}
func (e *MultiError) Appendf(format string, args ...any) {
*e = append(*e, fmt.Sprintf(format, args...))
}
// Combine converts this multiError to a singular error instance, returning nil if empty.
func (e MultiError) Combine() error {
if len(e) == 0 {
return nil
// NewMultiError returns a *MultiError with
// the capacity of its underlying error slice
// set to the provided value.
//
// This capacity can be exceeded if necessary,
// but it saves a teeny tiny bit of memory if
// callers set it correctly.
//
// If you don't know in advance what the capacity
// must be, just use new(MultiError) instead.
func NewMultiError(capacity int) *MultiError {
return &MultiError{
e: make([]error, 0, capacity),
}
return errors.New(`"` + strings.Join(e, `","`) + `"`)
}
// Append the given error to the MultiError.
func (m *MultiError) Append(err error) {
m.e = append(m.e, err)
}
// Append the given format string to the MultiError.
//
// It is valid to use %w in the format string
// to wrap any other errors.
func (m *MultiError) Appendf(format string, args ...any) {
m.e = append(m.e, fmt.Errorf(format, args...))
}
// Combine the MultiError into a single error.
//
// Unwrap will work on the returned error as expected.
func (m MultiError) Combine() error {
return errors.Join(m.e...)
}

View file

@ -0,0 +1,64 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtserror
import (
"errors"
"testing"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func TestMultiError(t *testing.T) {
errs := MultiError{
e: []error{
db.ErrNoEntries,
errors.New("oopsie woopsie we did a fucky wucky etc"),
},
}
errs.Appendf("appended + wrapped error: %w", db.ErrAlreadyExists)
err := errs.Combine()
if !errors.Is(err, db.ErrNoEntries) {
t.Error("should be db.ErrNoEntries")
}
if !errors.Is(err, db.ErrAlreadyExists) {
t.Error("should be db.ErrAlreadyExists")
}
if errors.Is(err, db.ErrBusyTimeout) {
t.Error("should not be db.ErrBusyTimeout")
}
errString := err.Error()
expected := `sql: no rows in result set
oopsie woopsie we did a fucky wucky etc
appended + wrapped error: already exists`
if errString != expected {
t.Errorf("errString '%s' should be '%s'", errString, expected)
}
}
func TestMultiErrorEmpty(t *testing.T) {
err := new(MultiError).Combine()
if err != nil {
t.Errorf("should be nil")
}
}

View file

@ -20,7 +20,6 @@ package processing
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -42,13 +41,13 @@ import (
func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error {
// Ensure status fully populated; including account, mentions, etc.
if err := p.state.DB.PopulateStatus(ctx, status); err != nil {
return fmt.Errorf("timelineAndNotifyStatus: error populating status with id %s: %w", status.ID, err)
return gtserror.Newf("error populating status with id %s: %w", status.ID, err)
}
// Get local followers of the account that posted the status.
follows, err := p.state.DB.GetAccountLocalFollowers(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineAndNotifyStatus: error getting local followers for account id %s: %w", status.AccountID, err)
return gtserror.Newf("error getting local followers for account id %s: %w", status.AccountID, err)
}
// If the poster is also local, add a fake entry for them
@ -66,12 +65,12 @@ func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmode
// This will also handle notifying any followers with notify
// set to true on their follow.
if err := p.timelineAndNotifyStatusForFollowers(ctx, status, follows); err != nil {
return fmt.Errorf("timelineAndNotifyStatus: error timelining status %s for followers: %w", status.ID, err)
return gtserror.Newf("error timelining status %s for followers: %w", status.ID, err)
}
// Notify each local account that's mentioned by this status.
if err := p.notifyStatusMentions(ctx, status); err != nil {
return fmt.Errorf("timelineAndNotifyStatus: error notifying status mentions for status %s: %w", status.ID, err)
return gtserror.Newf("error notifying status mentions for status %s: %w", status.ID, err)
}
return nil
@ -79,7 +78,7 @@ func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmode
func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, status *gtsmodel.Status, follows []*gtsmodel.Follow) error {
var (
errs = make(gtserror.MultiError, 0, len(follows))
errs = gtserror.NewMultiError(len(follows))
boost = status.BoostOfID != ""
reply = status.InReplyToURI != ""
)
@ -100,7 +99,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
follow.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
errs.Appendf("error list timelining status: %w", err)
continue
}
@ -113,7 +112,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
status,
stream.TimelineList+":"+listEntry.ListID, // key streamType to this specific list
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
errs.Appendf("error list timelining status: %w", err)
continue
}
}
@ -128,7 +127,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
status,
stream.TimelineHome,
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error home timelining status: %w", err))
errs.Appendf("error home timelining status: %w", err)
continue
} else if !timelined {
// Status wasn't added to home tomeline,
@ -162,11 +161,15 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
status.AccountID,
status.ID,
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error notifying account %s about new status: %w", follow.AccountID, err))
errs.Appendf("error notifying account %s about new status: %w", follow.AccountID, err)
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
// timelineStatus uses the provided ingest function to put the given
@ -185,7 +188,7 @@ func (p *Processor) timelineStatus(
// Make sure the status is timelineable.
// This works for both home and list timelines.
if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err)
err = gtserror.Newf("error getting timelineability for status for timeline with id %s: %w", account.ID, err)
return false, err
} else if !timelineable {
// Nothing to do.
@ -194,7 +197,7 @@ func (p *Processor) timelineStatus(
// Ingest status into given timeline using provided function.
if inserted, err := ingest(ctx, timelineID, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err)
err = gtserror.Newf("error ingesting status %s: %w", status.ID, err)
return false, err
} else if !inserted {
// Nothing more to do.
@ -204,12 +207,12 @@ func (p *Processor) timelineStatus(
// The status was inserted so stream it to the user.
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, account)
if err != nil {
err = fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %w", status.ID, err)
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return true, err
}
if err := p.stream.Update(apiStatus, account, []string{streamType}); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err)
err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
return true, err
}
@ -217,7 +220,7 @@ func (p *Processor) timelineStatus(
}
func (p *Processor) notifyStatusMentions(ctx context.Context, status *gtsmodel.Status) error {
errs := make(gtserror.MultiError, 0, len(status.Mentions))
errs := gtserror.NewMultiError(len(status.Mentions))
for _, m := range status.Mentions {
if err := p.notify(
@ -231,7 +234,11 @@ func (p *Processor) notifyStatusMentions(ctx context.Context, status *gtsmodel.S
}
}
return errs.Combine()
if err := errs.Combine(); err != nil {
return gtserror.Newf("%w", err)
}
return nil
}
func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
@ -255,13 +262,13 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Proper error while checking.
return fmt.Errorf("notifyFollow: db error checking for previous follow request notification: %w", err)
return gtserror.Newf("db error checking for previous follow request notification: %w", err)
}
if prevNotif != nil {
// Previous notification existed, delete.
if err := p.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); err != nil {
return fmt.Errorf("notifyFollow: db error removing previous follow request notification %s: %w", prevNotif.ID, err)
return gtserror.Newf("db error removing previous follow request notification %s: %w", prevNotif.ID, err)
}
}
@ -319,7 +326,7 @@ func (p *Processor) notify(
) error {
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
return fmt.Errorf("notify: error getting target account %s: %w", targetAccountID, err)
return gtserror.Newf("error getting target account %s: %w", targetAccountID, err)
}
if !targetAccount.IsLocal() {
@ -340,7 +347,7 @@ func (p *Processor) notify(
return nil
} else if !errors.Is(err, db.ErrNoEntries) {
// Real error.
return fmt.Errorf("notify: error checking existence of notification: %w", err)
return gtserror.Newf("error checking existence of notification: %w", err)
}
// Notification doesn't yet exist, so
@ -354,17 +361,17 @@ func (p *Processor) notify(
}
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notify: error putting notification in database: %w", err)
return gtserror.Newf("error putting notification in database: %w", err)
}
// Stream notification to the user.
apiNotif, err := p.tc.NotificationToAPINotification(ctx, notif)
if err != nil {
return fmt.Errorf("notify: error converting notification to api representation: %w", err)
return gtserror.Newf("error converting notification to api representation: %w", err)
}
if err := p.stream.Notify(apiNotif, targetAccount); err != nil {
return fmt.Errorf("notify: error streaming notification to account: %w", err)
return gtserror.Newf("error streaming notification to account: %w", err)
}
return nil
@ -479,7 +486,7 @@ func (p *Processor) invalidateStatusFromTimelines(ctx context.Context, statusID
func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) error {
instance, err := p.state.DB.GetInstance(ctx, config.GetHost())
if err != nil {
return fmt.Errorf("emailReport: error getting instance: %w", err)
return gtserror.Newf("error getting instance: %w", err)
}
toAddresses, err := p.state.DB.GetInstanceModeratorAddresses(ctx)
@ -488,20 +495,20 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er
// No registered moderator addresses.
return nil
}
return fmt.Errorf("emailReport: error getting instance moderator addresses: %w", err)
return gtserror.Newf("error getting instance moderator addresses: %w", err)
}
if report.Account == nil {
report.Account, err = p.state.DB.GetAccountByID(ctx, report.AccountID)
if err != nil {
return fmt.Errorf("emailReport: error getting report account: %w", err)
return gtserror.Newf("error getting report account: %w", err)
}
}
if report.TargetAccount == nil {
report.TargetAccount, err = p.state.DB.GetAccountByID(ctx, report.TargetAccountID)
if err != nil {
return fmt.Errorf("emailReport: error getting report target account: %w", err)
return gtserror.Newf("error getting report target account: %w", err)
}
}
@ -514,7 +521,7 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er
}
if err := p.emailSender.SendNewReportEmail(toAddresses, reportData); err != nil {
return fmt.Errorf("emailReport: error emailing instance moderators: %w", err)
return gtserror.Newf("error emailing instance moderators: %w", err)
}
return nil
@ -523,7 +530,7 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er
func (p *Processor) emailReportClosed(ctx context.Context, report *gtsmodel.Report) error {
user, err := p.state.DB.GetUserByAccountID(ctx, report.Account.ID)
if err != nil {
return fmt.Errorf("emailReportClosed: db error getting user: %w", err)
return gtserror.Newf("db error getting user: %w", err)
}
if user.ConfirmedAt.IsZero() || !*user.Approved || *user.Disabled || user.Email == "" {
@ -537,20 +544,20 @@ func (p *Processor) emailReportClosed(ctx context.Context, report *gtsmodel.Repo
instance, err := p.state.DB.GetInstance(ctx, config.GetHost())
if err != nil {
return fmt.Errorf("emailReportClosed: db error getting instance: %w", err)
return gtserror.Newf("db error getting instance: %w", err)
}
if report.Account == nil {
report.Account, err = p.state.DB.GetAccountByID(ctx, report.AccountID)
if err != nil {
return fmt.Errorf("emailReportClosed: error getting report account: %w", err)
return gtserror.Newf("error getting report account: %w", err)
}
}
if report.TargetAccount == nil {
report.TargetAccount, err = p.state.DB.GetAccountByID(ctx, report.TargetAccountID)
if err != nil {
return fmt.Errorf("emailReportClosed: error getting report target account: %w", err)
return gtserror.Newf("error getting report target account: %w", err)
}
}

View file

@ -190,18 +190,18 @@ func (m *manager) GetOldestIndexedID(ctx context.Context, timelineID string) str
}
func (m *manager) WipeItemFromAllTimelines(ctx context.Context, itemID string) error {
errors := gtserror.MultiError{}
errs := new(gtserror.MultiError)
m.timelines.Range(func(_ any, v any) bool {
if _, err := v.(Timeline).Remove(ctx, itemID); err != nil {
errors.Append(err)
errs.Append(err)
}
return true // always continue range
})
if len(errors) > 0 {
return gtserror.Newf("error(s) wiping status %s: %w", itemID, errors.Combine())
if err := errs.Combine(); err != nil {
return gtserror.Newf("error(s) wiping status %s: %w", itemID, errs.Combine())
}
return nil
@ -213,21 +213,21 @@ func (m *manager) WipeItemsFromAccountID(ctx context.Context, timelineID string,
}
func (m *manager) UnprepareItemFromAllTimelines(ctx context.Context, itemID string) error {
errors := gtserror.MultiError{}
errs := new(gtserror.MultiError)
// Work through all timelines held by this
// manager, and call Unprepare for each.
m.timelines.Range(func(_ any, v any) bool {
// nolint:forcetypeassert
if err := v.(Timeline).Unprepare(ctx, itemID); err != nil {
errors.Append(err)
errs.Append(err)
}
return true // always continue range
})
if len(errors) > 0 {
return gtserror.Newf("error(s) unpreparing status %s: %w", itemID, errors.Combine())
if err := errs.Combine(); err != nil {
return gtserror.Newf("error(s) unpreparing status %s: %w", itemID, errs.Combine())
}
return nil