From e70bf8a6c82e3d5c943550b364fc6f8120f6f07e Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Fri, 7 Jul 2023 11:34:12 +0200 Subject: [PATCH] [chore/bugfix] Domain block tidying up, Implement first pass of `207 Multi-Status` (#1886) * [chore/refactor] update domain block processing * expose domain block import errors a lil better * move/remove unused query keys --- internal/api/client/admin/admin.go | 2 - .../api/client/admin/domainblockcreate.go | 71 ++- internal/api/client/admin/domainblockget.go | 25 +- internal/api/client/admin/domainblocksget.go | 15 +- internal/api/model/multistatus.go | 90 +++ internal/api/util/parsequery.go | 148 ++--- internal/cache/gts.go | 22 + internal/config/config.go | 4 + internal/config/defaults.go | 4 + internal/config/helpers.gen.go | 75 +++ internal/db/bundb/bundb.go | 3 +- internal/db/bundb/domain.go | 29 +- internal/db/bundb/instance.go | 218 ++++++- internal/db/domain.go | 12 +- internal/db/instance.go | 9 + internal/federation/federatingprotocol.go | 2 +- internal/processing/admin/domainblock.go | 589 +++++++++++------- internal/processing/instance.go | 28 +- test/envparsing.sh | 3 + 19 files changed, 969 insertions(+), 380 deletions(-) create mode 100644 internal/api/model/multistatus.go diff --git a/internal/api/client/admin/admin.go b/internal/api/client/admin/admin.go index 4079dd97..a6c825b2 100644 --- a/internal/api/client/admin/admin.go +++ b/internal/api/client/admin/admin.go @@ -42,8 +42,6 @@ const ( EmailPath = BasePath + "/email" EmailTestPath = EmailPath + "/test" - ExportQueryKey = "export" - ImportQueryKey = "import" IDKey = "id" FilterQueryKey = "filter" MaxShortcodeDomainKey = "max_shortcode_domain" diff --git a/internal/api/client/admin/domainblockcreate.go b/internal/api/client/admin/domainblockcreate.go index 5177cb03..148fad7c 100644 --- a/internal/api/client/admin/domainblockcreate.go +++ b/internal/api/client/admin/domainblockcreate.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -140,48 +139,78 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { return } - imp := false - importString := c.Query(ImportQueryKey) - if importString != "" { - i, err := strconv.ParseBool(importString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", ImportQueryKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - imp = i + importing, errWithCode := apiutil.ParseDomainBlockImport(c.Query(apiutil.DomainBlockImportKey), false) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - form := &apimodel.DomainBlockCreateRequest{} + form := new(apimodel.DomainBlockCreateRequest) if err := c.ShouldBind(form); err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return } - if err := validateCreateDomainBlock(form, imp); err != nil { - err := fmt.Errorf("error validating form: %s", err) + if err := validateCreateDomainBlock(form, importing); err != nil { + err := fmt.Errorf("error validating form: %w", err) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return } - if imp { - // we're importing multiple blocks - domainBlocks, errWithCode := m.processor.Admin().DomainBlocksImport(c.Request.Context(), authed.Account, form.Domains) + if !importing { + // Single domain block creation. + domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate( + c.Request.Context(), + authed.Account, + form.Domain, + form.Obfuscate, + form.PublicComment, + form.PrivateComment, + "", // No sub ID for single block creation. + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.JSON(http.StatusOK, domainBlocks) + + c.JSON(http.StatusOK, domainBlock) return } - // we're just creating one block - domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate(c.Request.Context(), authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") + // We're importing multiple domain blocks, + // so we're looking at a multi-status response. + multiStatus, errWithCode := m.processor.Admin().DomainBlocksImport( + c.Request.Context(), + authed.Account, + form.Domains, // Pass the file through. + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.JSON(http.StatusOK, domainBlock) + + // TODO: Return 207 and multiStatus data nicely + // when supported by the admin panel. + + if multiStatus.Metadata.Failure != 0 { + failures := make(map[string]any, multiStatus.Metadata.Failure) + for _, entry := range multiStatus.Data { + // nolint:forcetypeassert + failures[entry.Resource.(string)] = entry.Message + } + + err := fmt.Errorf("one or more errors importing domain blocks: %+v", failures) + apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1) + return + } + + // Success, return slice of domain blocks. + domainBlocks := make([]any, 0, multiStatus.Metadata.Success) + for _, entry := range multiStatus.Data { + domainBlocks = append(domainBlocks, entry.Resource) + } + + c.JSON(http.StatusOK, domainBlocks) } func validateCreateDomainBlock(form *apimodel.DomainBlockCreateRequest, imp bool) error { diff --git a/internal/api/client/admin/domainblockget.go b/internal/api/client/admin/domainblockget.go index 335faed9..87bb75a2 100644 --- a/internal/api/client/admin/domainblockget.go +++ b/internal/api/client/admin/domainblockget.go @@ -18,10 +18,8 @@ package admin import ( - "errors" "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" @@ -87,26 +85,19 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) { return } - domainBlockID := c.Param(IDKey) - if domainBlockID == "" { - err := errors.New("no domain block id specified") - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + domainBlockID, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - export := false - exportString := c.Query(ExportQueryKey) - if exportString != "" { - i, err := strconv.ParseBool(exportString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - export = i + export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), authed.Account, domainBlockID, export) + domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), domainBlockID, export) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/admin/domainblocksget.go b/internal/api/client/admin/domainblocksget.go index d641fc0e..68947f47 100644 --- a/internal/api/client/admin/domainblocksget.go +++ b/internal/api/client/admin/domainblocksget.go @@ -20,7 +20,6 @@ package admin import ( "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" @@ -92,16 +91,10 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) { return } - export := false - exportString := c.Query(ExportQueryKey) - if exportString != "" { - i, err := strconv.ParseBool(exportString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - export = i + export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } domainBlocks, errWithCode := m.processor.Admin().DomainBlocksGet(c.Request.Context(), authed.Account, export) diff --git a/internal/api/model/multistatus.go b/internal/api/model/multistatus.go new file mode 100644 index 00000000..cac8b4f9 --- /dev/null +++ b/internal/api/model/multistatus.go @@ -0,0 +1,90 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package model + +// MultiStatus models a multistatus HTTP response body. +// This model should be transmitted along with http code +// 207 MULTI-STATUS to indicate a mixture of responses. +// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/207 +// +// swagger:model multiStatus +type MultiStatus struct { + Data []MultiStatusEntry `json:"data"` + Metadata MultiStatusMetadata `json:"metadata"` +} + +// MultiStatusEntry models one entry in multistatus data. +// It can model either a success or a failure. The type +// and value of `Resource` is left to the discretion of +// the caller, but at minimum it should be expected to be +// JSON-serializable. +// +// swagger:model multiStatusEntry +type MultiStatusEntry struct { + // The resource/result for this entry. + // Value may be any type, check the docs + // per endpoint to see which to expect. + Resource any `json:"resource"` + // Message/error message for this entry. + Message string `json:"message"` + // HTTP status code of this entry. + Status int `json:"status"` +} + +// MultiStatusMetadata models an at-a-glance summary of +// the data contained in the MultiStatus. +// +// swagger:model multiStatusMetadata +type MultiStatusMetadata struct { + // Success count + failure count. + Total int `json:"total"` + // Count of successful results (2xx). + Success int `json:"success"` + // Count of unsuccessful results (!2xx). + Failure int `json:"failure"` +} + +// NewMultiStatus returns a new MultiStatus API model with +// the provided entries, which will be iterated through to +// look for 2xx and non 2xx status codes, in order to count +// successes and failures. +func NewMultiStatus(entries []MultiStatusEntry) *MultiStatus { + var ( + successCount int + failureCount int + total = len(entries) + ) + + for _, e := range entries { + // Outside 2xx range = failure. + if e.Status > 299 || e.Status < 200 { + failureCount++ + } else { + successCount++ + } + } + + return &MultiStatus{ + Data: entries, + Metadata: MultiStatusMetadata{ + Total: total, + Success: successCount, + Failure: failureCount, + }, + } +} diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go index 460ca3e0..f5966bca 100644 --- a/internal/api/util/parsequery.go +++ b/internal/api/util/parsequery.go @@ -27,6 +27,7 @@ import ( const ( /* Common keys */ + IDKey = "id" LimitKey = "limit" LocalKey = "local" MaxIDKey = "max_id" @@ -41,6 +42,11 @@ const ( SearchQueryKey = "q" SearchResolveKey = "resolve" SearchTypeKey = "type" + + /* Domain block keys */ + + DomainBlockExportKey = "export" + DomainBlockImportKey = "import" ) // parseError returns gtserror.WithCode set to 400 Bad Request, to indicate @@ -50,6 +56,8 @@ func parseError(key string, value, defaultValue any, err error) gtserror.WithCod return gtserror.NewErrorBadRequest(err, err.Error()) } +// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate +// to the caller a required key value was not provided, or was empty. func requiredError(key string) gtserror.WithCode { err := fmt.Errorf("required key %s was not set or had empty value", key) return gtserror.NewErrorBadRequest(err, err.Error()) @@ -60,111 +68,51 @@ func requiredError(key string) gtserror.WithCode { */ func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.WithCode) { - key := LimitKey - - if value == "" { - return defaultValue, nil - } - - i, err := strconv.Atoi(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - if i > max { - i = max - } else if i < min { - i = min - } - - return i, nil + return parseInt(value, defaultValue, max, min, LimitKey) } func ParseLocal(value string, defaultValue bool) (bool, gtserror.WithCode) { - key := LimitKey - - if value == "" { - return defaultValue, nil - } - - i, err := strconv.ParseBool(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - return i, nil + return parseBool(value, defaultValue, LocalKey) } func ParseSearchExcludeUnreviewed(value string, defaultValue bool) (bool, gtserror.WithCode) { - key := SearchExcludeUnreviewedKey - - if value == "" { - return defaultValue, nil - } - - i, err := strconv.ParseBool(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - return i, nil + return parseBool(value, defaultValue, SearchExcludeUnreviewedKey) } func ParseSearchFollowing(value string, defaultValue bool) (bool, gtserror.WithCode) { - key := SearchFollowingKey - - if value == "" { - return defaultValue, nil - } - - i, err := strconv.ParseBool(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - return i, nil + return parseBool(value, defaultValue, SearchFollowingKey) } func ParseSearchOffset(value string, defaultValue int, max, min int) (int, gtserror.WithCode) { - key := SearchOffsetKey - - if value == "" { - return defaultValue, nil - } - - i, err := strconv.Atoi(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - if i > max { - i = max - } else if i < min { - i = min - } - - return i, nil + return parseInt(value, defaultValue, max, min, SearchOffsetKey) } func ParseSearchResolve(value string, defaultValue bool) (bool, gtserror.WithCode) { - key := SearchResolveKey + return parseBool(value, defaultValue, SearchResolveKey) +} - if value == "" { - return defaultValue, nil - } +func ParseDomainBlockExport(value string, defaultValue bool) (bool, gtserror.WithCode) { + return parseBool(value, defaultValue, DomainBlockExportKey) +} - i, err := strconv.ParseBool(value) - if err != nil { - return defaultValue, parseError(key, value, defaultValue, err) - } - - return i, nil +func ParseDomainBlockImport(value string, defaultValue bool) (bool, gtserror.WithCode) { + return parseBool(value, defaultValue, DomainBlockImportKey) } /* Parse functions for *REQUIRED* parameters. */ +func ParseID(value string) (string, gtserror.WithCode) { + key := IDKey + + if value == "" { + return "", requiredError(key) + } + + return value, nil +} + func ParseSearchLookup(value string) (string, gtserror.WithCode) { key := SearchLookupKey @@ -184,3 +132,39 @@ func ParseSearchQuery(value string) (string, gtserror.WithCode) { return value, nil } + +/* + Internal functions +*/ + +func parseBool(value string, defaultValue bool, key string) (bool, gtserror.WithCode) { + if value == "" { + return defaultValue, nil + } + + i, err := strconv.ParseBool(value) + if err != nil { + return defaultValue, parseError(key, value, defaultValue, err) + } + + return i, nil +} + +func parseInt(value string, defaultValue int, max int, min int, key string) (int, gtserror.WithCode) { + if value == "" { + return defaultValue, nil + } + + i, err := strconv.Atoi(value) + if err != nil { + return defaultValue, parseError(key, value, defaultValue, err) + } + + if i > max { + i = max + } else if i < min { + i = min + } + + return i, nil +} diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 3a2d0973..4b2e65b9 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -35,6 +35,7 @@ type GTSCaches struct { emojiCategory *result.Cache[*gtsmodel.EmojiCategory] follow *result.Cache[*gtsmodel.Follow] followRequest *result.Cache[*gtsmodel.FollowRequest] + instance *result.Cache[*gtsmodel.Instance] list *result.Cache[*gtsmodel.List] listEntry *result.Cache[*gtsmodel.ListEntry] media *result.Cache[*gtsmodel.MediaAttachment] @@ -59,6 +60,7 @@ func (c *GTSCaches) Init() { c.initEmojiCategory() c.initFollow() c.initFollowRequest() + c.initInstance() c.initList() c.initListEntry() c.initMedia() @@ -80,6 +82,7 @@ func (c *GTSCaches) Start() { tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq()) tryStart(c.list, config.GetCacheGTSListSweepFreq()) tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStart(c.media, config.GetCacheGTSMediaSweepFreq()) @@ -106,6 +109,7 @@ func (c *GTSCaches) Stop() { tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq()) tryStop(c.list, config.GetCacheGTSListSweepFreq()) tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStop(c.media, config.GetCacheGTSMediaSweepFreq()) @@ -154,6 +158,11 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] { return c.followRequest } +// Instance provides access to the gtsmodel Instance database cache. +func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] { + return c.instance +} + // List provides access to the gtsmodel List database cache. func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] { return c.list @@ -301,6 +310,19 @@ func (c *GTSCaches) initFollowRequest() { c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true) } +func (c *GTSCaches) initInstance() { + c.instance = result.New([]result.Lookup{ + {Name: "ID"}, + {Name: "Domain"}, + }, func(i1 *gtsmodel.Instance) *gtsmodel.Instance { + i2 := new(gtsmodel.Instance) + *i2 = *i1 + return i1 + }, config.GetCacheGTSInstanceMaxSize()) + c.instance.SetTTL(config.GetCacheGTSInstanceTTL(), true) + c.emojiCategory.IgnoreErrors(ignoreErrors) +} + func (c *GTSCaches) initList() { c.list = result.New([]result.Lookup{ {Name: "ID"}, diff --git a/internal/config/config.go b/internal/config/config.go index 8dcbcaf9..c809bbc1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -200,6 +200,10 @@ type GTSCacheConfiguration struct { FollowRequestTTL time.Duration `name:"follow-request-ttl"` FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"` + InstanceMaxSize int `name:"instance-max-size"` + InstanceTTL time.Duration `name:"instance-ttl"` + InstanceSweepFreq time.Duration `name:"instance-sweep-freq"` + ListMaxSize int `name:"list-max-size"` ListTTL time.Duration `name:"list-ttl"` ListSweepFreq time.Duration `name:"list-sweep-freq"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index c11f436d..1cb53c8e 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -154,6 +154,10 @@ var Defaults = Configuration{ FollowRequestTTL: time.Minute * 30, FollowRequestSweepFreq: time.Minute, + InstanceMaxSize: 2000, + InstanceTTL: time.Minute * 30, + InstanceSweepFreq: time.Minute, + ListMaxSize: 2000, ListTTL: time.Minute * 30, ListSweepFreq: time.Minute, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 16865f32..c82eba3b 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2828,6 +2828,81 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration { // SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) } +// GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field +func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) { + st.mutex.Lock() + v = st.config.Cache.GTS.InstanceMaxSize + st.mutex.Unlock() + return +} + +// SetCacheGTSInstanceMaxSize safely sets the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field +func (st *ConfigState) SetCacheGTSInstanceMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.InstanceMaxSize = v + st.reloadToViper() +} + +// CacheGTSInstanceMaxSizeFlag returns the flag name for the 'Cache.GTS.InstanceMaxSize' field +func CacheGTSInstanceMaxSizeFlag() string { return "cache-gts-instance-max-size" } + +// GetCacheGTSInstanceMaxSize safely fetches the value for global configuration 'Cache.GTS.InstanceMaxSize' field +func GetCacheGTSInstanceMaxSize() int { return global.GetCacheGTSInstanceMaxSize() } + +// SetCacheGTSInstanceMaxSize safely sets the value for global configuration 'Cache.GTS.InstanceMaxSize' field +func SetCacheGTSInstanceMaxSize(v int) { global.SetCacheGTSInstanceMaxSize(v) } + +// GetCacheGTSInstanceTTL safely fetches the Configuration value for state's 'Cache.GTS.InstanceTTL' field +func (st *ConfigState) GetCacheGTSInstanceTTL() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.InstanceTTL + st.mutex.Unlock() + return +} + +// SetCacheGTSInstanceTTL safely sets the Configuration value for state's 'Cache.GTS.InstanceTTL' field +func (st *ConfigState) SetCacheGTSInstanceTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.InstanceTTL = v + st.reloadToViper() +} + +// CacheGTSInstanceTTLFlag returns the flag name for the 'Cache.GTS.InstanceTTL' field +func CacheGTSInstanceTTLFlag() string { return "cache-gts-instance-ttl" } + +// GetCacheGTSInstanceTTL safely fetches the value for global configuration 'Cache.GTS.InstanceTTL' field +func GetCacheGTSInstanceTTL() time.Duration { return global.GetCacheGTSInstanceTTL() } + +// SetCacheGTSInstanceTTL safely sets the value for global configuration 'Cache.GTS.InstanceTTL' field +func SetCacheGTSInstanceTTL(v time.Duration) { global.SetCacheGTSInstanceTTL(v) } + +// GetCacheGTSInstanceSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field +func (st *ConfigState) GetCacheGTSInstanceSweepFreq() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.InstanceSweepFreq + st.mutex.Unlock() + return +} + +// SetCacheGTSInstanceSweepFreq safely sets the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field +func (st *ConfigState) SetCacheGTSInstanceSweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.InstanceSweepFreq = v + st.reloadToViper() +} + +// CacheGTSInstanceSweepFreqFlag returns the flag name for the 'Cache.GTS.InstanceSweepFreq' field +func CacheGTSInstanceSweepFreqFlag() string { return "cache-gts-instance-sweep-freq" } + +// GetCacheGTSInstanceSweepFreq safely fetches the value for global configuration 'Cache.GTS.InstanceSweepFreq' field +func GetCacheGTSInstanceSweepFreq() time.Duration { return global.GetCacheGTSInstanceSweepFreq() } + +// SetCacheGTSInstanceSweepFreq safely sets the value for global configuration 'Cache.GTS.InstanceSweepFreq' field +func SetCacheGTSInstanceSweepFreq(v time.Duration) { global.SetCacheGTSInstanceSweepFreq(v) } + // GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field func (st *ConfigState) GetCacheGTSListMaxSize() (v int) { st.mutex.Lock() diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 9d616954..ee28800b 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -179,7 +179,8 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { state: state, }, Instance: &instanceDB{ - conn: conn, + conn: conn, + state: state, }, List: &listDB{ conn: conn, diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 5c92645d..2e8ce2a6 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -42,7 +42,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain return err } - // Attempt to store domain in DB + // Attempt to store domain block in DB if _, err := d.conn.NewInsert(). Model(block). Exec(ctx); err != nil { @@ -82,6 +82,33 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel return &block, nil } +func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) { + blocks := []*gtsmodel.DomainBlock{} + + if err := d.conn. + NewSelect(). + Model(&blocks). + Scan(ctx); err != nil { + return nil, d.conn.ProcessError(err) + } + + return blocks, nil +} + +func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) { + var block gtsmodel.DomainBlock + + q := d.conn. + NewSelect(). + Model(&block). + Where("? = ?", bun.Ident("domain_block.id"), id) + if err := q.Scan(ctx); err != nil { + return nil, d.conn.ProcessError(err) + } + + return &block, nil +} + func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { // Normalize the domain as punycode domain, err := util.Punify(domain) diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 95f6ad5b..60d77600 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -19,15 +19,23 @@ package bundb import ( "context" + "time" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) type instanceDB struct { - conn *DBConn + conn *DBConn + state *state.State } func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { @@ -99,62 +107,236 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i } func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) { - instance := >smodel.Instance{} + // Normalize the domain as punycode + var err error + domain, err = util.Punify(domain) + if err != nil { + return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) + } - if err := i.conn. - NewSelect(). - Model(instance). - Where("? = ?", bun.Ident("instance.domain"), domain). - Scan(ctx); err != nil { - return nil, i.conn.ProcessError(err) + return i.getInstance( + ctx, + "Domain", + func(instance *gtsmodel.Instance) error { + return i.conn.NewSelect(). + Model(instance). + Where("? = ?", bun.Ident("instance.domain"), domain). + Scan(ctx) + }, + domain, + ) +} + +func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) { + return i.getInstance( + ctx, + "ID", + func(instance *gtsmodel.Instance) error { + return i.conn.NewSelect(). + Model(instance). + Where("? = ?", bun.Ident("instance.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) { + // Fetch instance from database cache with loader callback + instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { + var instance gtsmodel.Instance + + // Not cached! Perform database query. + if err := dbQuery(&instance); err != nil { + return nil, i.conn.ProcessError(err) + } + + return &instance, nil + }, keyParts...) + if err != nil { + return nil, err + } + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return instance, nil + } + + // Further populate the instance fields where applicable. + if err := i.populateInstance(ctx, instance); err != nil { + return nil, err } return instance, nil } +func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error { + var ( + err error + errs = make(gtserror.MultiError, 0, 2) + ) + + if instance.DomainBlockID != "" && instance.DomainBlock == nil { + // Instance domain block is not set, fetch from database. + instance.DomainBlock, err = i.state.DB.GetDomainBlock( + gtscontext.SetBarebones(ctx), + instance.Domain, + ) + if err != nil { + errs.Append(gtserror.Newf("error populating instance domain block: %w", err)) + } + } + + if instance.ContactAccountID != "" && instance.ContactAccount == nil { + // Instance domain block is not set, fetch from database. + instance.ContactAccount, err = i.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + instance.ContactAccountID, + ) + if err != nil { + errs.Append(gtserror.Newf("error populating instance contact account: %w", err)) + } + } + + return errs.Combine() +} + +func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error { + // Normalize the domain as punycode + var err error + instance.Domain, err = util.Punify(instance.Domain) + if err != nil { + return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) + } + + return i.state.Caches.GTS.Instance().Store(instance, func() error { + _, err := i.conn.NewInsert().Model(instance).Exec(ctx) + return i.conn.ProcessError(err) + }) +} + +func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error { + // Normalize the domain as punycode + var err error + instance.Domain, err = util.Punify(instance.Domain) + if err != nil { + return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) + } + + // Update the instance's last-updated + instance.UpdatedAt = time.Now() + if len(columns) != 0 { + columns = append(columns, "updated_at") + } + + return i.state.Caches.GTS.Instance().Store(instance, func() error { + _, err := i.conn. + NewUpdate(). + Model(instance). + Where("? = ?", bun.Ident("instance.id"), instance.ID). + Column(columns...). + Exec(ctx) + return i.conn.ProcessError(err) + }) +} + func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) { - instances := []*gtsmodel.Instance{} + instanceIDs := []string{} q := i.conn. NewSelect(). - Model(&instances). + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). + // Select just the IDs of each instance. + Column("instance.id"). + // Exclude our own instance. Where("? != ?", bun.Ident("instance.domain"), config.GetHost()) if !includeSuspended { q = q.Where("? IS NULL", bun.Ident("instance.suspended_at")) } - if err := q.Scan(ctx); err != nil { + if err := q.Scan(ctx, &instanceIDs); err != nil { return nil, i.conn.ProcessError(err) } + if len(instanceIDs) == 0 { + return make([]*gtsmodel.Instance, 0), nil + } + + instances := make([]*gtsmodel.Instance, 0, len(instanceIDs)) + + for _, id := range instanceIDs { + // Select each instance by its ID. + instance, err := i.GetInstanceByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting instance %q: %v", id, err) + continue + } + + // Append to return slice. + instances = append(instances, instance) + } + return instances, nil } func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { - accounts := []*gtsmodel.Account{} + // Ensure reasonable + if limit < 0 { + limit = 0 + } - q := i.conn.NewSelect(). - Model(&accounts). + // Normalize the domain as punycode. + var err error + domain, err = util.Punify(domain) + if err != nil { + return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) + } + + // Make educated guess for slice size + accountIDs := make([]string, 0, limit) + + q := i.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + // Select just the account ID. + Column("account.id"). + // Select accounts belonging to given domain. Where("? = ?", bun.Ident("account.domain"), domain). Order("account.id DESC") - if maxID != "" { - q = q.Where("? < ?", bun.Ident("account.id"), maxID) + if maxID == "" { + maxID = id.Highest } + q = q.Where("? < ?", bun.Ident("account.id"), maxID) if limit > 0 { q = q.Limit(limit) } - if err := q.Scan(ctx); err != nil { + if err := q.Scan(ctx, &accountIDs); err != nil { return nil, i.conn.ProcessError(err) } - if len(accounts) == 0 { + // Catch case of no accounts early. + count := len(accountIDs) + if count == 0 { return nil, db.ErrNoEntries } + // Select each account by its ID. + accounts := make([]*gtsmodel.Account, 0, count) + for _, id := range accountIDs { + account, err := i.state.DB.GetAccountByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting account %q: %v", id, err) + continue + } + + // Append to return slice. + accounts = append(accounts, account) + } + return accounts, nil } diff --git a/internal/db/domain.go b/internal/db/domain.go index 8918d6fe..d859752a 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -26,13 +26,19 @@ import ( // Domain contains DB functions related to domains and domain blocks. type Domain interface { - // CreateDomainBlock ... + // CreateDomainBlock puts the given instance-level domain block into the database. CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error - // GetDomainBlock ... + // GetDomainBlock returns one instance-level domain block with the given domain, if it exists. GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error) - // DeleteDomainBlock ... + // GetDomainBlockByID returns one instance-level domain block with the given id, if it exists. + GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error) + + // GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance. + GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) + + // DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists. DeleteDomainBlock(ctx context.Context, domain string) Error // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). diff --git a/internal/db/instance.go b/internal/db/instance.go index 3166a0a1..ab40c7a8 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -37,6 +37,15 @@ type Instance interface { // GetInstance returns the instance entry for the given domain, if it exists. GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error) + // GetInstanceByID returns the instance entry corresponding to the given id, if it exists. + GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) + + // PutInstance inserts the given instance into the database. + PutInstance(ctx context.Context, instance *gtsmodel.Instance) error + + // UpdateInstance updates the given instance entry. + UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error + // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index ec74de09..ef42639e 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -256,7 +256,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr return nil, false, err } - if err := f.db.Put(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) { + if err := f.db.PutInstance(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) { err = gtserror.Newf("error inserting instance entry for %s: %w", pubKeyOwner.Host, err) return nil, false, err } diff --git a/internal/processing/admin/domainblock.go b/internal/processing/admin/domainblock.go index f1097000..c645f287 100644 --- a/internal/processing/admin/domainblock.go +++ b/internal/processing/admin/domainblock.go @@ -25,7 +25,7 @@ import ( "fmt" "io" "mime/multipart" - "strings" + "net/http" "time" "codeberg.org/gruf/go-kv" @@ -40,20 +40,30 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" ) -func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { - // domain blocks will always be lowercase - domain = strings.ToLower(domain) +// DomainBlockCreate creates an instance-level block against the given domain, +// and then processes side effects of that block (deleting accounts, media, etc). +// +// If a domain block already exists for the domain, side effects will be retried. +func (p *Processor) DomainBlockCreate( + ctx context.Context, + account *gtsmodel.Account, + domain string, + obfuscate bool, + publicComment string, + privateComment string, + subscriptionID string, +) (*apimodel.DomainBlock, gtserror.WithCode) { + // Check if a block already exists for this domain. + domainBlock, err := p.state.DB.GetDomainBlock(ctx, domain) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Something went wrong in the DB. + err = gtserror.Newf("db error getting domain block %s: %w", domain, err) + return nil, gtserror.NewErrorInternalError(err) + } - // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work - block, err := p.state.DB.GetDomainBlock(ctx, domain) - if err != nil { - if !errors.Is(err, db.ErrNoEntries) { - // something went wrong in the DB - return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error checking for existence of domain block %s: %s", domain, err)) - } - - // there's no block for this domain yet so create one - newBlock := >smodel.DomainBlock{ + if domainBlock == nil { + // No block exists yet, create it. + domainBlock = >smodel.DomainBlock{ ID: id.NewULID(), Domain: domain, CreatedByAccountID: account.ID, @@ -63,249 +73,408 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc SubscriptionID: subscriptionID, } - // Insert the new block into the database - if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err)) + // Insert the new block into the database. + if err := p.state.DB.CreateDomainBlock(ctx, domainBlock); err != nil { + err = gtserror.Newf("db error putting domain block %s: %s", domain, err) + return nil, gtserror.NewErrorInternalError(err) } - - // Set the newly created block - block = newBlock - - // Process the side effects of the domain block asynchronously since it might take a while - go func() { - p.initiateDomainBlockSideEffects(context.Background(), account, block) - }() } - // Convert our gts model domain block into an API model - apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, block, false) - if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting domain block to frontend/api representation %s: %s", domain, err)) - } + // Process the side effects of the domain block + // asynchronously since it might take a while. + p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) { + p.domainBlockSideEffects(ctx, account, domainBlock) + }) - return apiDomainBlock, nil + return p.apiDomainBlock(ctx, domainBlock) } -// initiateDomainBlockSideEffects should be called asynchronously, to process the side effects of a domain block: +// DomainBlocksImport handles the import of multiple domain blocks, +// by calling the DomainBlockCreate function for each domain in the +// provided file. Will return a slice of processed domain blocks. // -// 1. Strip most info away from the instance entry for the domain. -// 2. Delete the instance account for that instance if it exists. -// 3. Select all accounts from this instance and pass them through the delete functionality of the processor. -func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) { - l := log.WithContext(ctx).WithFields(kv.Fields{{"domain", block.Domain}}...) - l.Debug("processing domain block side effects") - - // if we have an instance entry for this domain, update it with the new block ID and clear all fields - instance := >smodel.Instance{} - if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { - updatingColumns := []string{ - "title", - "updated_at", - "suspended_at", - "domain_block_id", - "short_description", - "description", - "terms", - "contact_email", - "contact_account_username", - "contact_account_id", - "version", - } - instance.Title = "" - instance.UpdatedAt = time.Now() - instance.SuspendedAt = time.Now() - instance.DomainBlockID = block.ID - instance.ShortDescription = "" - instance.Description = "" - instance.Terms = "" - instance.ContactEmail = "" - instance.ContactAccountUsername = "" - instance.ContactAccountID = "" - instance.Version = "" - if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { - l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) - } - l.Debug("domainBlockProcessSideEffects: instance entry updated") - } - - // if we have an instance account for this instance, delete it - if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { - if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil { - l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) - } - } - - // delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) - - limit := 20 // just select 20 accounts at a time so we don't nuke our DB/mem with one huge query - var maxID string // this is initially an empty string so we'll start at the top of accounts list (sorted by ID) - -selectAccountsLoop: - for { - accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit) - if err != nil { - if err == db.ErrNoEntries { - // no accounts left for this instance so we're done - l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain) - break selectAccountsLoop - } - // an actual error has occurred - l.Errorf("domainBlockProcessSideEffects: db error selecting accounts for domain %s: %s", block.Domain, err) - break selectAccountsLoop - } - - for i, a := range accounts { - l.Debugf("putting delete for account %s in the clientAPI channel", a.Username) - - // pass the account delete through the client api channel for processing - p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ - APObjectType: ap.ActorPerson, - APActivityType: ap.ActivityDelete, - GTSModel: block, - OriginAccount: account, - TargetAccount: a, - }) - - // if this is the last account in the slice, set the maxID appropriately for the next query - if i == len(accounts)-1 { - maxID = a.ID - } - } - } -} - -// DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file. -func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { - f, err := domains.Open() +// In the case of total failure, a gtserror.WithCode will be returned +// so that the caller can respond appropriately. In the case of +// partial or total success, a MultiStatus model will be returned, +// which contains information about success/failure count, so that +// the caller can retry any failures as they wish. +func (p *Processor) DomainBlocksImport( + ctx context.Context, + account *gtsmodel.Account, + domainsF *multipart.FileHeader, +) (*apimodel.MultiStatus, gtserror.WithCode) { + // Open the provided file. + file, err := domainsF.Open() if err != nil { - return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error opening attachment: %s", err)) + err = gtserror.Newf("error opening attachment: %w", err) + return nil, gtserror.NewErrorBadRequest(err, err.Error()) } + defer file.Close() + + // Copy the file contents into a buffer. buf := new(bytes.Buffer) - size, err := io.Copy(buf, f) + size, err := io.Copy(buf, file) if err != nil { - return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error reading attachment: %s", err)) + err = gtserror.Newf("error reading attachment: %w", err) + return nil, gtserror.NewErrorBadRequest(err, err.Error()) } + + // Ensure we actually read something. if size == 0 { - return nil, gtserror.NewErrorBadRequest(errors.New("DomainBlocksImport: could not read provided attachment: size 0 bytes")) + err = gtserror.New("error reading attachment: size 0 bytes") + return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - d := []apimodel.DomainBlock{} - if err := json.Unmarshal(buf.Bytes(), &d); err != nil { - return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: could not read provided attachment: %s", err)) + // Parse bytes as slice of domain blocks. + domainBlocks := make([]*apimodel.DomainBlock, 0) + if err := json.Unmarshal(buf.Bytes(), &domainBlocks); err != nil { + err = gtserror.Newf("error parsing attachment as domain blocks: %w", err) + return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - blocks := []*apimodel.DomainBlock{} - for _, d := range d { - block, err := p.DomainBlockCreate(ctx, account, d.Domain.Domain, false, d.PublicComment, "", "") - if err != nil { - return nil, err + count := len(domainBlocks) + if count == 0 { + err = gtserror.New("error importing domain blocks: 0 entries provided") + return nil, gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Try to process each domain block, differentiating + // between successes and errors so that the caller can + // try failed imports again if desired. + multiStatusEntries := make([]apimodel.MultiStatusEntry, 0, count) + + for _, domainBlock := range domainBlocks { + var ( + domain = domainBlock.Domain.Domain + obfuscate = domainBlock.Obfuscate + publicComment = domainBlock.PublicComment + privateComment = domainBlock.PrivateComment + subscriptionID = "" // No sub ID for imports. + errWithCode gtserror.WithCode + ) + + domainBlock, errWithCode = p.DomainBlockCreate( + ctx, + account, + domain, + obfuscate, + publicComment, + privateComment, + subscriptionID, + ) + + var entry *apimodel.MultiStatusEntry + + if errWithCode != nil { + entry = &apimodel.MultiStatusEntry{ + // Use the failed domain entry as the resource value. + Resource: domain, + Message: errWithCode.Safe(), + Status: errWithCode.Code(), + } + } else { + entry = &apimodel.MultiStatusEntry{ + // Use successfully created API model domain block as the resource value. + Resource: domainBlock, + Message: http.StatusText(http.StatusOK), + Status: http.StatusOK, + } } - blocks = append(blocks, block) + multiStatusEntries = append(multiStatusEntries, *entry) } - return blocks, nil + return apimodel.NewMultiStatus(multiStatusEntries), nil } -// DomainBlocksGet returns all existing domain blocks. -// If export is true, the format will be suitable for writing out to an export. +// DomainBlocksGet returns all existing domain blocks. If export is +// true, the format will be suitable for writing out to an export. func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { - domainBlocks := []*gtsmodel.DomainBlock{} - - if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil { - if !errors.Is(err, db.ErrNoEntries) { - // something has gone really wrong - return nil, gtserror.NewErrorInternalError(err) - } + domainBlocks, err := p.state.DB.GetDomainBlocks(ctx) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = gtserror.Newf("db error getting domain blocks: %w", err) + return nil, gtserror.NewErrorInternalError(err) } - apiDomainBlocks := []*apimodel.DomainBlock{} - for _, b := range domainBlocks { - apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, b, export) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) + apiDomainBlocks := make([]*apimodel.DomainBlock, 0, len(domainBlocks)) + for _, domainBlock := range domainBlocks { + apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock) + if errWithCode != nil { + return nil, errWithCode } + apiDomainBlocks = append(apiDomainBlocks, apiDomainBlock) } return apiDomainBlocks, nil } -// DomainBlockGet returns one domain block with the given id. -// If export is true, the format will be suitable for writing out to an export. -func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { - domainBlock := >smodel.DomainBlock{} - - if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil { - if !errors.Is(err, db.ErrNoEntries) { - // something has gone really wrong - return nil, gtserror.NewErrorInternalError(err) - } - // there are no entries for this ID - return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id)) - } - - apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, export) +// DomainBlockGet returns one domain block with the given id. If export +// is true, the format will be suitable for writing out to an export. +func (p *Processor) DomainBlockGet(ctx context.Context, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { + domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id) if err != nil { + if errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("no domain block exists with id %s", id) + return nil, gtserror.NewErrorNotFound(err, err.Error()) + } + + // Something went wrong in the DB. + err = gtserror.Newf("db error getting domain block %s: %w", id, err) return nil, gtserror.NewErrorInternalError(err) } + return p.apiDomainBlock(ctx, domainBlock) +} + +// DomainBlockDelete removes one domain block with the given ID, +// and processes side effects of removing the block asynchronously. +func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { + domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id) + if err != nil { + if !errors.Is(err, db.ErrNoEntries) { + // Real error. + err = gtserror.Newf("db error getting domain block: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // There are just no entries for this ID. + err = fmt.Errorf("no domain block entry exists with ID %s", id) + return nil, gtserror.NewErrorNotFound(err, err.Error()) + } + + // Prepare the domain block to return, *before* the deletion goes through. + apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock) + if errWithCode != nil { + return nil, errWithCode + } + + // Copy value of the domain block. + domainBlockC := new(gtsmodel.DomainBlock) + *domainBlockC = *domainBlock + + // Delete the original domain block. + if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { + err = gtserror.Newf("db error deleting domain block: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Process the side effects of the domain unblock + // asynchronously since it might take a while. + p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) { + p.domainUnblockSideEffects(ctx, domainBlockC) // Use the copy. + }) + return apiDomainBlock, nil } -// DomainBlockDelete removes one domain block with the given ID. -func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { - domainBlock := >smodel.DomainBlock{} +// stubbifyInstance renders the given instance as a stub, +// removing most information from it and marking it as +// suspended. +// +// For caller's convenience, this function returns the db +// names of all columns that are updated by it. +func stubbifyInstance(instance *gtsmodel.Instance, domainBlockID string) []string { + instance.Title = "" + instance.SuspendedAt = time.Now() + instance.DomainBlockID = domainBlockID + instance.ShortDescription = "" + instance.Description = "" + instance.Terms = "" + instance.ContactEmail = "" + instance.ContactAccountUsername = "" + instance.ContactAccountID = "" + instance.Version = "" - if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil { - if !errors.Is(err, db.ErrNoEntries) { - // something has gone really wrong - return nil, gtserror.NewErrorInternalError(err) - } - // there are no entries for this ID - return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id)) + return []string{ + "title", + "suspended_at", + "domain_block_id", + "short_description", + "description", + "terms", + "contact_email", + "contact_account_username", + "contact_account_id", + "version", } +} - // prepare the domain block to return +// apiDomainBlock is a cheeky shortcut function for returning the API +// version of the given domainBlock, or an appropriate error if +// something goes wrong. +func (p *Processor) apiDomainBlock(ctx context.Context, domainBlock *gtsmodel.DomainBlock) (*apimodel.DomainBlock, gtserror.WithCode) { apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, false) if err != nil { + err = gtserror.Newf("error converting domain block for %s to api model : %w", domainBlock.Domain, err) return nil, gtserror.NewErrorInternalError(err) } - // Delete the domain block - if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - - // remove the domain block reference from the instance, if we have an entry for it - i := >smodel.Instance{} - if err := p.state.DB.GetWhere(ctx, []db.Where{ - {Key: "domain", Value: domainBlock.Domain}, - {Key: "domain_block_id", Value: id}, - }, i); err == nil { - updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"} - i.SuspendedAt = time.Time{} - i.DomainBlockID = "" - i.UpdatedAt = time.Now() - if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) - } - } - - // unsuspend all accounts whose suspension origin was this domain block - // 1. remove the 'suspended_at' entry from their accounts - if err := p.state.DB.UpdateWhere(ctx, []db.Where{ - {Key: "suspension_origin", Value: domainBlock.ID}, - }, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err)) - } - - // 2. remove the 'suspension_origin' entry from their accounts - if err := p.state.DB.UpdateWhere(ctx, []db.Where{ - {Key: "suspension_origin", Value: domainBlock.ID}, - }, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) - } - return apiDomainBlock, nil } + +// rangeAccounts iterates through all accounts originating from the +// given domain, and calls the provided range function on each account. +// If an error is returned from the range function, the loop will stop +// and return the error. +func (p *Processor) rangeAccounts( + ctx context.Context, + domain string, + rangeF func(*gtsmodel.Account) error, +) error { + var ( + limit = 50 // Limit selection to avoid spiking mem/cpu. + maxID string // Start with empty string to select from top. + ) + + for { + // Get (next) page of accounts. + accounts, err := p.state.DB.GetInstanceAccounts(ctx, domain, maxID, limit) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real db error. + return gtserror.Newf("db error getting instance accounts: %w", err) + } + + if len(accounts) == 0 { + // No accounts left, we're done. + return nil + } + + // Set next max ID for paging down. + maxID = accounts[len(accounts)-1].ID + + // Call provided range function. + for _, account := range accounts { + if err := rangeF(account); err != nil { + return err + } + } + } +} + +// domainBlockSideEffects processes the side effects of a domain block: +// +// 1. Strip most info away from the instance entry for the domain. +// 2. Pass each account from the domain to the processor for deletion. +// +// It should be called asynchronously, since it can take a while when +// there are many accounts present on the given domain. +func (p *Processor) domainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) { + l := log. + WithContext(ctx). + WithFields(kv.Fields{ + {"domain", block.Domain}, + }...) + l.Debug("processing domain block side effects") + + // If we have an instance entry for this domain, + // update it with the new block ID and clear all fields + instance, err := p.state.DB.GetInstance(ctx, block.Domain) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + l.Errorf("db error getting instance %s: %q", block.Domain, err) + } + + if instance != nil { + // We had an entry for this domain. + columns := stubbifyInstance(instance, block.ID) + if err := p.state.DB.UpdateInstance(ctx, instance, columns...); err != nil { + l.Errorf("db error updating instance: %s", err) + } else { + l.Debug("instance entry updated") + } + } + + // For each account that belongs to this domain, create + // an account delete message to process via the client API + // worker pool, to remove that account's posts, media, etc. + msgs := []messages.FromClientAPI{} + if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error { + msgs = append(msgs, messages.FromClientAPI{ + APObjectType: ap.ActorPerson, + APActivityType: ap.ActivityDelete, + GTSModel: block, + OriginAccount: account, + TargetAccount: account, + }) + + return nil + }); err != nil { + l.Errorf("error while ranging through accounts: %q", err) + } + + // Batch process all accreted messages. + p.state.Workers.EnqueueClientAPI(ctx, msgs...) +} + +// domainUnblockSideEffects processes the side effects of undoing a +// domain block: +// +// 1. Mark instance entry as no longer suspended. +// 2. Mark each account from the domain as no longer suspended, if the +// suspension origin corresponds to the ID of the provided domain block. +// +// It should be called asynchronously, since it can take a while when +// there are many accounts present on the given domain. +func (p *Processor) domainUnblockSideEffects(ctx context.Context, block *gtsmodel.DomainBlock) { + l := log. + WithContext(ctx). + WithFields(kv.Fields{ + {"domain", block.Domain}, + }...) + l.Debug("processing domain unblock side effects") + + // Update instance entry for this domain, if we have it. + instance, err := p.state.DB.GetInstance(ctx, block.Domain) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + l.Errorf("db error getting instance %s: %q", block.Domain, err) + } + + if instance != nil { + // We had an entry, update it to signal + // that it's no longer suspended. + instance.SuspendedAt = time.Time{} + instance.DomainBlockID = "" + if err := p.state.DB.UpdateInstance( + ctx, + instance, + "suspended_at", + "domain_block_id", + ); err != nil { + l.Errorf("db error updating instance: %s", err) + } else { + l.Debug("instance entry updated") + } + } + + // Unsuspend all accounts whose suspension origin was this domain block. + if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error { + if account.SuspensionOrigin == "" || account.SuspendedAt.IsZero() { + // Account wasn't suspended, nothing to do. + return nil + } + + if account.SuspensionOrigin != block.ID { + // Account was suspended, but not by + // this domain block, leave it alone. + return nil + } + + // Account was suspended by this domain + // block, mark it as unsuspended. + account.SuspendedAt = time.Time{} + account.SuspensionOrigin = "" + + if err := p.state.DB.UpdateAccount( + ctx, + account, + "suspended_at", + "suspension_origin", + ); err != nil { + return gtserror.Newf("db error updating account %s: %w", account.Username, err) + } + + return nil + }); err != nil { + l.Errorf("error while ranging through accounts: %q", err) + } +} diff --git a/internal/processing/instance.go b/internal/processing/instance.go index a9d849fa..ac63814c 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -34,11 +34,12 @@ import ( ) func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) { - i := >smodel.Instance{} - if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { + instance, err := p.state.DB.GetInstance(ctx, config.GetHost()) + if err != nil { return nil, err } - return i, nil + + return instance, nil } func (p *Processor) InstanceGetV1(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode) { @@ -137,9 +138,10 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool, func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.InstanceV1, gtserror.WithCode) { // fetch the instance entry from the db for processing - i := >smodel.Instance{} host := config.GetHost() - if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil { + + instance, err := p.state.DB.GetInstance(ctx, host) + if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err)) } @@ -157,7 +159,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("site title invalid: %s", err)) } updatingColumns = append(updatingColumns, "title") - i.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title + instance.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title } // validate & update site contact account if it's set on the form @@ -192,7 +194,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, err.Error()) } updatingColumns = append(updatingColumns, "contact_account_id") - i.ContactAccountID = contactAccount.ID + instance.ContactAccountID = contactAccount.ID } // validate & update site contact email if it's set on the form @@ -204,7 +206,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe } } updatingColumns = append(updatingColumns, "contact_email") - i.ContactEmail = contactEmail + instance.ContactEmail = contactEmail } // validate & update site short description if it's set on the form @@ -213,7 +215,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, err.Error()) } updatingColumns = append(updatingColumns, "short_description") - i.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it + instance.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it } // validate & update site description if it's set on the form @@ -222,7 +224,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, err.Error()) } updatingColumns = append(updatingColumns, "description") - i.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it + instance.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it } // validate & update site terms if it's set on the form @@ -231,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe return nil, gtserror.NewErrorBadRequest(err, err.Error()) } updatingColumns = append(updatingColumns, "terms") - i.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it + instance.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it } var updateInstanceAccount bool @@ -273,12 +275,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe } if len(updatingColumns) != 0 { - if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateInstance(ctx, instance, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) } } - ai, err := p.tc.InstanceToAPIV1Instance(ctx, i) + ai, err := p.tc.InstanceToAPIV1Instance(ctx, instance) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err)) } diff --git a/test/envparsing.sh b/test/envparsing.sh index 01a0877b..eb6b8da1 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -40,6 +40,9 @@ EXPECT=$(cat <<"EOF" "follow-request-ttl": 1800000000000, "follow-sweep-freq": 60000000000, "follow-ttl": 1800000000000, + "instance-max-size": 2000, + "instance-sweep-freq": 60000000000, + "instance-ttl": 1800000000000, "list-entry-max-size": 2000, "list-entry-sweep-freq": 60000000000, "list-entry-ttl": 1800000000000,