[chore/bugfix] Domain block tidying up, Implement first pass of 207 Multi-Status (#1886)

* [chore/refactor] update domain block processing

* expose domain block import errors a lil better

* move/remove unused query keys
This commit is contained in:
tobi 2023-07-07 11:34:12 +02:00 committed by GitHub
parent d9c69f6ce0
commit e70bf8a6c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 969 additions and 380 deletions

View file

@ -42,8 +42,6 @@ const (
EmailPath = BasePath + "/email" EmailPath = BasePath + "/email"
EmailTestPath = EmailPath + "/test" EmailTestPath = EmailPath + "/test"
ExportQueryKey = "export"
ImportQueryKey = "import"
IDKey = "id" IDKey = "id"
FilterQueryKey = "filter" FilterQueryKey = "filter"
MaxShortcodeDomainKey = "max_shortcode_domain" MaxShortcodeDomainKey = "max_shortcode_domain"

View file

@ -21,7 +21,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -140,48 +139,78 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
return return
} }
imp := false importing, errWithCode := apiutil.ParseDomainBlockImport(c.Query(apiutil.DomainBlockImportKey), false)
importString := c.Query(ImportQueryKey) if errWithCode != nil {
if importString != "" { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
i, err := strconv.ParseBool(importString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", ImportQueryKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
imp = i
}
form := &apimodel.DomainBlockCreateRequest{} form := new(apimodel.DomainBlockCreateRequest)
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
if err := validateCreateDomainBlock(form, imp); err != nil { if err := validateCreateDomainBlock(form, importing); err != nil {
err := fmt.Errorf("error validating form: %s", err) err := fmt.Errorf("error validating form: %w", err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
if imp { if !importing {
// we're importing multiple blocks // Single domain block creation.
domainBlocks, errWithCode := m.processor.Admin().DomainBlocksImport(c.Request.Context(), authed.Account, form.Domains) domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate(
c.Request.Context(),
authed.Account,
form.Domain,
form.Obfuscate,
form.PublicComment,
form.PrivateComment,
"", // No sub ID for single block creation.
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
c.JSON(http.StatusOK, domainBlocks)
return
}
// we're just creating one block c.JSON(http.StatusOK, domainBlock)
domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate(c.Request.Context(), authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") return
}
// We're importing multiple domain blocks,
// so we're looking at a multi-status response.
multiStatus, errWithCode := m.processor.Admin().DomainBlocksImport(
c.Request.Context(),
authed.Account,
form.Domains, // Pass the file through.
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
c.JSON(http.StatusOK, domainBlock)
// TODO: Return 207 and multiStatus data nicely
// when supported by the admin panel.
if multiStatus.Metadata.Failure != 0 {
failures := make(map[string]any, multiStatus.Metadata.Failure)
for _, entry := range multiStatus.Data {
// nolint:forcetypeassert
failures[entry.Resource.(string)] = entry.Message
}
err := fmt.Errorf("one or more errors importing domain blocks: %+v", failures)
apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1)
return
}
// Success, return slice of domain blocks.
domainBlocks := make([]any, 0, multiStatus.Metadata.Success)
for _, entry := range multiStatus.Data {
domainBlocks = append(domainBlocks, entry.Resource)
}
c.JSON(http.StatusOK, domainBlocks)
} }
func validateCreateDomainBlock(form *apimodel.DomainBlockCreateRequest, imp bool) error { func validateCreateDomainBlock(form *apimodel.DomainBlockCreateRequest, imp bool) error {

View file

@ -18,10 +18,8 @@
package admin package admin
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -87,26 +85,19 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
return return
} }
domainBlockID := c.Param(IDKey) domainBlockID, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
if domainBlockID == "" { if errWithCode != nil {
err := errors.New("no domain block id specified") apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
export := false export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false)
exportString := c.Query(ExportQueryKey) if errWithCode != nil {
if exportString != "" { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
i, err := strconv.ParseBool(exportString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
export = i
}
domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), authed.Account, domainBlockID, export) domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), domainBlockID, export)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return

View file

@ -20,7 +20,6 @@ package admin
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -92,17 +91,11 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
return return
} }
export := false export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false)
exportString := c.Query(ExportQueryKey) if errWithCode != nil {
if exportString != "" { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
i, err := strconv.ParseBool(exportString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
export = i
}
domainBlocks, errWithCode := m.processor.Admin().DomainBlocksGet(c.Request.Context(), authed.Account, export) domainBlocks, errWithCode := m.processor.Admin().DomainBlocksGet(c.Request.Context(), authed.Account, export)
if errWithCode != nil { if errWithCode != nil {

View file

@ -0,0 +1,90 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package model
// MultiStatus models a multistatus HTTP response body.
// This model should be transmitted along with http code
// 207 MULTI-STATUS to indicate a mixture of responses.
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/207
//
// swagger:model multiStatus
type MultiStatus struct {
Data []MultiStatusEntry `json:"data"`
Metadata MultiStatusMetadata `json:"metadata"`
}
// MultiStatusEntry models one entry in multistatus data.
// It can model either a success or a failure. The type
// and value of `Resource` is left to the discretion of
// the caller, but at minimum it should be expected to be
// JSON-serializable.
//
// swagger:model multiStatusEntry
type MultiStatusEntry struct {
// The resource/result for this entry.
// Value may be any type, check the docs
// per endpoint to see which to expect.
Resource any `json:"resource"`
// Message/error message for this entry.
Message string `json:"message"`
// HTTP status code of this entry.
Status int `json:"status"`
}
// MultiStatusMetadata models an at-a-glance summary of
// the data contained in the MultiStatus.
//
// swagger:model multiStatusMetadata
type MultiStatusMetadata struct {
// Success count + failure count.
Total int `json:"total"`
// Count of successful results (2xx).
Success int `json:"success"`
// Count of unsuccessful results (!2xx).
Failure int `json:"failure"`
}
// NewMultiStatus returns a new MultiStatus API model with
// the provided entries, which will be iterated through to
// look for 2xx and non 2xx status codes, in order to count
// successes and failures.
func NewMultiStatus(entries []MultiStatusEntry) *MultiStatus {
var (
successCount int
failureCount int
total = len(entries)
)
for _, e := range entries {
// Outside 2xx range = failure.
if e.Status > 299 || e.Status < 200 {
failureCount++
} else {
successCount++
}
}
return &MultiStatus{
Data: entries,
Metadata: MultiStatusMetadata{
Total: total,
Success: successCount,
Failure: failureCount,
},
}
}

View file

@ -27,6 +27,7 @@ import (
const ( const (
/* Common keys */ /* Common keys */
IDKey = "id"
LimitKey = "limit" LimitKey = "limit"
LocalKey = "local" LocalKey = "local"
MaxIDKey = "max_id" MaxIDKey = "max_id"
@ -41,6 +42,11 @@ const (
SearchQueryKey = "q" SearchQueryKey = "q"
SearchResolveKey = "resolve" SearchResolveKey = "resolve"
SearchTypeKey = "type" SearchTypeKey = "type"
/* Domain block keys */
DomainBlockExportKey = "export"
DomainBlockImportKey = "import"
) )
// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate // parseError returns gtserror.WithCode set to 400 Bad Request, to indicate
@ -50,6 +56,8 @@ func parseError(key string, value, defaultValue any, err error) gtserror.WithCod
return gtserror.NewErrorBadRequest(err, err.Error()) return gtserror.NewErrorBadRequest(err, err.Error())
} }
// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate
// to the caller a required key value was not provided, or was empty.
func requiredError(key string) gtserror.WithCode { func requiredError(key string) gtserror.WithCode {
err := fmt.Errorf("required key %s was not set or had empty value", key) err := fmt.Errorf("required key %s was not set or had empty value", key)
return gtserror.NewErrorBadRequest(err, err.Error()) return gtserror.NewErrorBadRequest(err, err.Error())
@ -60,111 +68,51 @@ func requiredError(key string) gtserror.WithCode {
*/ */
func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.WithCode) { func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.WithCode) {
key := LimitKey return parseInt(value, defaultValue, max, min, LimitKey)
if value == "" {
return defaultValue, nil
}
i, err := strconv.Atoi(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
if i > max {
i = max
} else if i < min {
i = min
}
return i, nil
} }
func ParseLocal(value string, defaultValue bool) (bool, gtserror.WithCode) { func ParseLocal(value string, defaultValue bool) (bool, gtserror.WithCode) {
key := LimitKey return parseBool(value, defaultValue, LocalKey)
if value == "" {
return defaultValue, nil
}
i, err := strconv.ParseBool(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
return i, nil
} }
func ParseSearchExcludeUnreviewed(value string, defaultValue bool) (bool, gtserror.WithCode) { func ParseSearchExcludeUnreviewed(value string, defaultValue bool) (bool, gtserror.WithCode) {
key := SearchExcludeUnreviewedKey return parseBool(value, defaultValue, SearchExcludeUnreviewedKey)
if value == "" {
return defaultValue, nil
}
i, err := strconv.ParseBool(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
return i, nil
} }
func ParseSearchFollowing(value string, defaultValue bool) (bool, gtserror.WithCode) { func ParseSearchFollowing(value string, defaultValue bool) (bool, gtserror.WithCode) {
key := SearchFollowingKey return parseBool(value, defaultValue, SearchFollowingKey)
if value == "" {
return defaultValue, nil
}
i, err := strconv.ParseBool(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
return i, nil
} }
func ParseSearchOffset(value string, defaultValue int, max, min int) (int, gtserror.WithCode) { func ParseSearchOffset(value string, defaultValue int, max, min int) (int, gtserror.WithCode) {
key := SearchOffsetKey return parseInt(value, defaultValue, max, min, SearchOffsetKey)
if value == "" {
return defaultValue, nil
}
i, err := strconv.Atoi(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
if i > max {
i = max
} else if i < min {
i = min
}
return i, nil
} }
func ParseSearchResolve(value string, defaultValue bool) (bool, gtserror.WithCode) { func ParseSearchResolve(value string, defaultValue bool) (bool, gtserror.WithCode) {
key := SearchResolveKey return parseBool(value, defaultValue, SearchResolveKey)
if value == "" {
return defaultValue, nil
} }
i, err := strconv.ParseBool(value) func ParseDomainBlockExport(value string, defaultValue bool) (bool, gtserror.WithCode) {
if err != nil { return parseBool(value, defaultValue, DomainBlockExportKey)
return defaultValue, parseError(key, value, defaultValue, err)
} }
return i, nil func ParseDomainBlockImport(value string, defaultValue bool) (bool, gtserror.WithCode) {
return parseBool(value, defaultValue, DomainBlockImportKey)
} }
/* /*
Parse functions for *REQUIRED* parameters. Parse functions for *REQUIRED* parameters.
*/ */
func ParseID(value string) (string, gtserror.WithCode) {
key := IDKey
if value == "" {
return "", requiredError(key)
}
return value, nil
}
func ParseSearchLookup(value string) (string, gtserror.WithCode) { func ParseSearchLookup(value string) (string, gtserror.WithCode) {
key := SearchLookupKey key := SearchLookupKey
@ -184,3 +132,39 @@ func ParseSearchQuery(value string) (string, gtserror.WithCode) {
return value, nil return value, nil
} }
/*
Internal functions
*/
func parseBool(value string, defaultValue bool, key string) (bool, gtserror.WithCode) {
if value == "" {
return defaultValue, nil
}
i, err := strconv.ParseBool(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
return i, nil
}
func parseInt(value string, defaultValue int, max int, min int, key string) (int, gtserror.WithCode) {
if value == "" {
return defaultValue, nil
}
i, err := strconv.Atoi(value)
if err != nil {
return defaultValue, parseError(key, value, defaultValue, err)
}
if i > max {
i = max
} else if i < min {
i = min
}
return i, nil
}

22
internal/cache/gts.go vendored
View file

@ -35,6 +35,7 @@ type GTSCaches struct {
emojiCategory *result.Cache[*gtsmodel.EmojiCategory] emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
follow *result.Cache[*gtsmodel.Follow] follow *result.Cache[*gtsmodel.Follow]
followRequest *result.Cache[*gtsmodel.FollowRequest] followRequest *result.Cache[*gtsmodel.FollowRequest]
instance *result.Cache[*gtsmodel.Instance]
list *result.Cache[*gtsmodel.List] list *result.Cache[*gtsmodel.List]
listEntry *result.Cache[*gtsmodel.ListEntry] listEntry *result.Cache[*gtsmodel.ListEntry]
media *result.Cache[*gtsmodel.MediaAttachment] media *result.Cache[*gtsmodel.MediaAttachment]
@ -59,6 +60,7 @@ func (c *GTSCaches) Init() {
c.initEmojiCategory() c.initEmojiCategory()
c.initFollow() c.initFollow()
c.initFollowRequest() c.initFollowRequest()
c.initInstance()
c.initList() c.initList()
c.initListEntry() c.initListEntry()
c.initMedia() c.initMedia()
@ -80,6 +82,7 @@ func (c *GTSCaches) Start() {
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq())
tryStart(c.list, config.GetCacheGTSListSweepFreq()) tryStart(c.list, config.GetCacheGTSListSweepFreq())
tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStart(c.media, config.GetCacheGTSMediaSweepFreq()) tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
@ -106,6 +109,7 @@ func (c *GTSCaches) Stop() {
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq())
tryStop(c.list, config.GetCacheGTSListSweepFreq()) tryStop(c.list, config.GetCacheGTSListSweepFreq())
tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStop(c.media, config.GetCacheGTSMediaSweepFreq()) tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
@ -154,6 +158,11 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
return c.followRequest return c.followRequest
} }
// Instance provides access to the gtsmodel Instance database cache.
func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] {
return c.instance
}
// List provides access to the gtsmodel List database cache. // List provides access to the gtsmodel List database cache.
func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] { func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] {
return c.list return c.list
@ -301,6 +310,19 @@ func (c *GTSCaches) initFollowRequest() {
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true) c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
} }
func (c *GTSCaches) initInstance() {
c.instance = result.New([]result.Lookup{
{Name: "ID"},
{Name: "Domain"},
}, func(i1 *gtsmodel.Instance) *gtsmodel.Instance {
i2 := new(gtsmodel.Instance)
*i2 = *i1
return i1
}, config.GetCacheGTSInstanceMaxSize())
c.instance.SetTTL(config.GetCacheGTSInstanceTTL(), true)
c.emojiCategory.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initList() { func (c *GTSCaches) initList() {
c.list = result.New([]result.Lookup{ c.list = result.New([]result.Lookup{
{Name: "ID"}, {Name: "ID"},

View file

@ -200,6 +200,10 @@ type GTSCacheConfiguration struct {
FollowRequestTTL time.Duration `name:"follow-request-ttl"` FollowRequestTTL time.Duration `name:"follow-request-ttl"`
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"` FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
InstanceMaxSize int `name:"instance-max-size"`
InstanceTTL time.Duration `name:"instance-ttl"`
InstanceSweepFreq time.Duration `name:"instance-sweep-freq"`
ListMaxSize int `name:"list-max-size"` ListMaxSize int `name:"list-max-size"`
ListTTL time.Duration `name:"list-ttl"` ListTTL time.Duration `name:"list-ttl"`
ListSweepFreq time.Duration `name:"list-sweep-freq"` ListSweepFreq time.Duration `name:"list-sweep-freq"`

View file

@ -154,6 +154,10 @@ var Defaults = Configuration{
FollowRequestTTL: time.Minute * 30, FollowRequestTTL: time.Minute * 30,
FollowRequestSweepFreq: time.Minute, FollowRequestSweepFreq: time.Minute,
InstanceMaxSize: 2000,
InstanceTTL: time.Minute * 30,
InstanceSweepFreq: time.Minute,
ListMaxSize: 2000, ListMaxSize: 2000,
ListTTL: time.Minute * 30, ListTTL: time.Minute * 30,
ListSweepFreq: time.Minute, ListSweepFreq: time.Minute,

View file

@ -2828,6 +2828,81 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field // SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) } func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
// GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field
func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.InstanceMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSInstanceMaxSize safely sets the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field
func (st *ConfigState) SetCacheGTSInstanceMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.InstanceMaxSize = v
st.reloadToViper()
}
// CacheGTSInstanceMaxSizeFlag returns the flag name for the 'Cache.GTS.InstanceMaxSize' field
func CacheGTSInstanceMaxSizeFlag() string { return "cache-gts-instance-max-size" }
// GetCacheGTSInstanceMaxSize safely fetches the value for global configuration 'Cache.GTS.InstanceMaxSize' field
func GetCacheGTSInstanceMaxSize() int { return global.GetCacheGTSInstanceMaxSize() }
// SetCacheGTSInstanceMaxSize safely sets the value for global configuration 'Cache.GTS.InstanceMaxSize' field
func SetCacheGTSInstanceMaxSize(v int) { global.SetCacheGTSInstanceMaxSize(v) }
// GetCacheGTSInstanceTTL safely fetches the Configuration value for state's 'Cache.GTS.InstanceTTL' field
func (st *ConfigState) GetCacheGTSInstanceTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.InstanceTTL
st.mutex.Unlock()
return
}
// SetCacheGTSInstanceTTL safely sets the Configuration value for state's 'Cache.GTS.InstanceTTL' field
func (st *ConfigState) SetCacheGTSInstanceTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.InstanceTTL = v
st.reloadToViper()
}
// CacheGTSInstanceTTLFlag returns the flag name for the 'Cache.GTS.InstanceTTL' field
func CacheGTSInstanceTTLFlag() string { return "cache-gts-instance-ttl" }
// GetCacheGTSInstanceTTL safely fetches the value for global configuration 'Cache.GTS.InstanceTTL' field
func GetCacheGTSInstanceTTL() time.Duration { return global.GetCacheGTSInstanceTTL() }
// SetCacheGTSInstanceTTL safely sets the value for global configuration 'Cache.GTS.InstanceTTL' field
func SetCacheGTSInstanceTTL(v time.Duration) { global.SetCacheGTSInstanceTTL(v) }
// GetCacheGTSInstanceSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field
func (st *ConfigState) GetCacheGTSInstanceSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.InstanceSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSInstanceSweepFreq safely sets the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field
func (st *ConfigState) SetCacheGTSInstanceSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.InstanceSweepFreq = v
st.reloadToViper()
}
// CacheGTSInstanceSweepFreqFlag returns the flag name for the 'Cache.GTS.InstanceSweepFreq' field
func CacheGTSInstanceSweepFreqFlag() string { return "cache-gts-instance-sweep-freq" }
// GetCacheGTSInstanceSweepFreq safely fetches the value for global configuration 'Cache.GTS.InstanceSweepFreq' field
func GetCacheGTSInstanceSweepFreq() time.Duration { return global.GetCacheGTSInstanceSweepFreq() }
// SetCacheGTSInstanceSweepFreq safely sets the value for global configuration 'Cache.GTS.InstanceSweepFreq' field
func SetCacheGTSInstanceSweepFreq(v time.Duration) { global.SetCacheGTSInstanceSweepFreq(v) }
// GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field // GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field
func (st *ConfigState) GetCacheGTSListMaxSize() (v int) { func (st *ConfigState) GetCacheGTSListMaxSize() (v int) {
st.mutex.Lock() st.mutex.Lock()

View file

@ -180,6 +180,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
}, },
Instance: &instanceDB{ Instance: &instanceDB{
conn: conn, conn: conn,
state: state,
}, },
List: &listDB{ List: &listDB{
conn: conn, conn: conn,

View file

@ -42,7 +42,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
return err return err
} }
// Attempt to store domain in DB // Attempt to store domain block in DB
if _, err := d.conn.NewInsert(). if _, err := d.conn.NewInsert().
Model(block). Model(block).
Exec(ctx); err != nil { Exec(ctx); err != nil {
@ -82,6 +82,33 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
return &block, nil return &block, nil
} }
func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) {
blocks := []*gtsmodel.DomainBlock{}
if err := d.conn.
NewSelect().
Model(&blocks).
Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err)
}
return blocks, nil
}
func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) {
var block gtsmodel.DomainBlock
q := d.conn.
NewSelect().
Model(&block).
Where("? = ?", bun.Ident("domain_block.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err)
}
return &block, nil
}
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err := util.Punify(domain) domain, err := util.Punify(domain)

View file

@ -19,15 +19,23 @@ package bundb
import ( import (
"context" "context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type instanceDB struct { type instanceDB struct {
conn *DBConn conn *DBConn
state *state.State
} }
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
@ -99,62 +107,236 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
} }
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) { func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) {
instance := &gtsmodel.Instance{} // Normalize the domain as punycode
var err error
domain, err = util.Punify(domain)
if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
}
if err := i.conn. return i.getInstance(
NewSelect(). ctx,
"Domain",
func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect().
Model(instance). Model(instance).
Where("? = ?", bun.Ident("instance.domain"), domain). Where("? = ?", bun.Ident("instance.domain"), domain).
Scan(ctx); err != nil { Scan(ctx)
},
domain,
)
}
func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) {
return i.getInstance(
ctx,
"ID",
func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect().
Model(instance).
Where("? = ?", bun.Ident("instance.id"), id).
Scan(ctx)
},
id,
)
}
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) {
// Fetch instance from database cache with loader callback
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance
// Not cached! Perform database query.
if err := dbQuery(&instance); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.conn.ProcessError(err)
} }
return &instance, nil
}, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return instance, nil
}
// Further populate the instance fields where applicable.
if err := i.populateInstance(ctx, instance); err != nil {
return nil, err
}
return instance, nil return instance, nil
} }
func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
)
if instance.DomainBlockID != "" && instance.DomainBlock == nil {
// Instance domain block is not set, fetch from database.
instance.DomainBlock, err = i.state.DB.GetDomainBlock(
gtscontext.SetBarebones(ctx),
instance.Domain,
)
if err != nil {
errs.Append(gtserror.Newf("error populating instance domain block: %w", err))
}
}
if instance.ContactAccountID != "" && instance.ContactAccount == nil {
// Instance domain block is not set, fetch from database.
instance.ContactAccount, err = i.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
instance.ContactAccountID,
)
if err != nil {
errs.Append(gtserror.Newf("error populating instance contact account: %w", err))
}
}
return errs.Combine()
}
func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error {
// Normalize the domain as punycode
var err error
instance.Domain, err = util.Punify(instance.Domain)
if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn.NewInsert().Model(instance).Exec(ctx)
return i.conn.ProcessError(err)
})
}
func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error {
// Normalize the domain as punycode
var err error
instance.Domain, err = util.Punify(instance.Domain)
if err != nil {
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
}
// Update the instance's last-updated
instance.UpdatedAt = time.Now()
if len(columns) != 0 {
columns = append(columns, "updated_at")
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn.
NewUpdate().
Model(instance).
Where("? = ?", bun.Ident("instance.id"), instance.ID).
Column(columns...).
Exec(ctx)
return i.conn.ProcessError(err)
})
}
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) { func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) {
instances := []*gtsmodel.Instance{} instanceIDs := []string{}
q := i.conn. q := i.conn.
NewSelect(). NewSelect().
Model(&instances). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
// Select just the IDs of each instance.
Column("instance.id").
// Exclude our own instance.
Where("? != ?", bun.Ident("instance.domain"), config.GetHost()) Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
if !includeSuspended { if !includeSuspended {
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at")) q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
} }
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx, &instanceIDs); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.conn.ProcessError(err)
} }
if len(instanceIDs) == 0 {
return make([]*gtsmodel.Instance, 0), nil
}
instances := make([]*gtsmodel.Instance, 0, len(instanceIDs))
for _, id := range instanceIDs {
// Select each instance by its ID.
instance, err := i.GetInstanceByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting instance %q: %v", id, err)
continue
}
// Append to return slice.
instances = append(instances, instance)
}
return instances, nil return instances, nil
} }
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
accounts := []*gtsmodel.Account{} // Ensure reasonable
if limit < 0 {
limit = 0
}
q := i.conn.NewSelect(). // Normalize the domain as punycode.
Model(&accounts). var err error
domain, err = util.Punify(domain)
if err != nil {
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
}
// Make educated guess for slice size
accountIDs := make([]string, 0, limit)
q := i.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
// Select just the account ID.
Column("account.id").
// Select accounts belonging to given domain.
Where("? = ?", bun.Ident("account.domain"), domain). Where("? = ?", bun.Ident("account.domain"), domain).
Order("account.id DESC") Order("account.id DESC")
if maxID != "" { if maxID == "" {
q = q.Where("? < ?", bun.Ident("account.id"), maxID) maxID = id.Highest
} }
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
if limit > 0 { if limit > 0 {
q = q.Limit(limit) q = q.Limit(limit)
} }
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.conn.ProcessError(err)
} }
if len(accounts) == 0 { // Catch case of no accounts early.
count := len(accountIDs)
if count == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
// Select each account by its ID.
accounts := make([]*gtsmodel.Account, 0, count)
for _, id := range accountIDs {
account, err := i.state.DB.GetAccountByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting account %q: %v", id, err)
continue
}
// Append to return slice.
accounts = append(accounts, account)
}
return accounts, nil return accounts, nil
} }

View file

@ -26,13 +26,19 @@ import (
// Domain contains DB functions related to domains and domain blocks. // Domain contains DB functions related to domains and domain blocks.
type Domain interface { type Domain interface {
// CreateDomainBlock ... // CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error
// GetDomainBlock ... // GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error)
// DeleteDomainBlock ... // GetDomainBlockByID returns one instance-level domain block with the given id, if it exists.
GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error)
// GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance.
GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error)
// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
DeleteDomainBlock(ctx context.Context, domain string) Error DeleteDomainBlock(ctx context.Context, domain string) Error
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).

View file

@ -37,6 +37,15 @@ type Instance interface {
// GetInstance returns the instance entry for the given domain, if it exists. // GetInstance returns the instance entry for the given domain, if it exists.
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error)
// GetInstanceByID returns the instance entry corresponding to the given id, if it exists.
GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error)
// PutInstance inserts the given instance into the database.
PutInstance(ctx context.Context, instance *gtsmodel.Instance) error
// UpdateInstance updates the given instance entry.
UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)

View file

@ -256,7 +256,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
return nil, false, err return nil, false, err
} }
if err := f.db.Put(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) { if err := f.db.PutInstance(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) {
err = gtserror.Newf("error inserting instance entry for %s: %w", pubKeyOwner.Host, err) err = gtserror.Newf("error inserting instance entry for %s: %w", pubKeyOwner.Host, err)
return nil, false, err return nil, false, err
} }

View file

@ -25,7 +25,7 @@ import (
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"strings" "net/http"
"time" "time"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
@ -40,20 +40,30 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/text"
) )
func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { // DomainBlockCreate creates an instance-level block against the given domain,
// domain blocks will always be lowercase // and then processes side effects of that block (deleting accounts, media, etc).
domain = strings.ToLower(domain) //
// If a domain block already exists for the domain, side effects will be retried.
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work func (p *Processor) DomainBlockCreate(
block, err := p.state.DB.GetDomainBlock(ctx, domain) ctx context.Context,
if err != nil { account *gtsmodel.Account,
if !errors.Is(err, db.ErrNoEntries) { domain string,
// something went wrong in the DB obfuscate bool,
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error checking for existence of domain block %s: %s", domain, err)) publicComment string,
privateComment string,
subscriptionID string,
) (*apimodel.DomainBlock, gtserror.WithCode) {
// Check if a block already exists for this domain.
domainBlock, err := p.state.DB.GetDomainBlock(ctx, domain)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Something went wrong in the DB.
err = gtserror.Newf("db error getting domain block %s: %w", domain, err)
return nil, gtserror.NewErrorInternalError(err)
} }
// there's no block for this domain yet so create one if domainBlock == nil {
newBlock := &gtsmodel.DomainBlock{ // No block exists yet, create it.
domainBlock = &gtsmodel.DomainBlock{
ID: id.NewULID(), ID: id.NewULID(),
Domain: domain, Domain: domain,
CreatedByAccountID: account.ID, CreatedByAccountID: account.ID,
@ -63,44 +73,221 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
SubscriptionID: subscriptionID, SubscriptionID: subscriptionID,
} }
// Insert the new block into the database // Insert the new block into the database.
if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil { if err := p.state.DB.CreateDomainBlock(ctx, domainBlock); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err)) err = gtserror.Newf("db error putting domain block %s: %s", domain, err)
return nil, gtserror.NewErrorInternalError(err)
}
} }
// Set the newly created block // Process the side effects of the domain block
block = newBlock // asynchronously since it might take a while.
p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) {
p.domainBlockSideEffects(ctx, account, domainBlock)
})
// Process the side effects of the domain block asynchronously since it might take a while return p.apiDomainBlock(ctx, domainBlock)
go func() {
p.initiateDomainBlockSideEffects(context.Background(), account, block)
}()
} }
// Convert our gts model domain block into an API model // DomainBlocksImport handles the import of multiple domain blocks,
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, block, false) // by calling the DomainBlockCreate function for each domain in the
// provided file. Will return a slice of processed domain blocks.
//
// In the case of total failure, a gtserror.WithCode will be returned
// so that the caller can respond appropriately. In the case of
// partial or total success, a MultiStatus model will be returned,
// which contains information about success/failure count, so that
// the caller can retry any failures as they wish.
func (p *Processor) DomainBlocksImport(
ctx context.Context,
account *gtsmodel.Account,
domainsF *multipart.FileHeader,
) (*apimodel.MultiStatus, gtserror.WithCode) {
// Open the provided file.
file, err := domainsF.Open()
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting domain block to frontend/api representation %s: %s", domain, err)) err = gtserror.Newf("error opening attachment: %w", err)
return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
defer file.Close()
// Copy the file contents into a buffer.
buf := new(bytes.Buffer)
size, err := io.Copy(buf, file)
if err != nil {
err = gtserror.Newf("error reading attachment: %w", err)
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
// Ensure we actually read something.
if size == 0 {
err = gtserror.New("error reading attachment: size 0 bytes")
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
// Parse bytes as slice of domain blocks.
domainBlocks := make([]*apimodel.DomainBlock, 0)
if err := json.Unmarshal(buf.Bytes(), &domainBlocks); err != nil {
err = gtserror.Newf("error parsing attachment as domain blocks: %w", err)
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
count := len(domainBlocks)
if count == 0 {
err = gtserror.New("error importing domain blocks: 0 entries provided")
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
// Try to process each domain block, differentiating
// between successes and errors so that the caller can
// try failed imports again if desired.
multiStatusEntries := make([]apimodel.MultiStatusEntry, 0, count)
for _, domainBlock := range domainBlocks {
var (
domain = domainBlock.Domain.Domain
obfuscate = domainBlock.Obfuscate
publicComment = domainBlock.PublicComment
privateComment = domainBlock.PrivateComment
subscriptionID = "" // No sub ID for imports.
errWithCode gtserror.WithCode
)
domainBlock, errWithCode = p.DomainBlockCreate(
ctx,
account,
domain,
obfuscate,
publicComment,
privateComment,
subscriptionID,
)
var entry *apimodel.MultiStatusEntry
if errWithCode != nil {
entry = &apimodel.MultiStatusEntry{
// Use the failed domain entry as the resource value.
Resource: domain,
Message: errWithCode.Safe(),
Status: errWithCode.Code(),
}
} else {
entry = &apimodel.MultiStatusEntry{
// Use successfully created API model domain block as the resource value.
Resource: domainBlock,
Message: http.StatusText(http.StatusOK),
Status: http.StatusOK,
}
}
multiStatusEntries = append(multiStatusEntries, *entry)
}
return apimodel.NewMultiStatus(multiStatusEntries), nil
}
// DomainBlocksGet returns all existing domain blocks. If export is
// true, the format will be suitable for writing out to an export.
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks, err := p.state.DB.GetDomainBlocks(ctx)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = gtserror.Newf("db error getting domain blocks: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
apiDomainBlocks := make([]*apimodel.DomainBlock, 0, len(domainBlocks))
for _, domainBlock := range domainBlocks {
apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock)
if errWithCode != nil {
return nil, errWithCode
}
apiDomainBlocks = append(apiDomainBlocks, apiDomainBlock)
}
return apiDomainBlocks, nil
}
// DomainBlockGet returns one domain block with the given id. If export
// is true, the format will be suitable for writing out to an export.
func (p *Processor) DomainBlockGet(ctx context.Context, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("no domain block exists with id %s", id)
return nil, gtserror.NewErrorNotFound(err, err.Error())
}
// Something went wrong in the DB.
err = gtserror.Newf("db error getting domain block %s: %w", id, err)
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiDomainBlock(ctx, domainBlock)
}
// DomainBlockDelete removes one domain block with the given ID,
// and processes side effects of removing the block asynchronously.
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// Real error.
err = gtserror.Newf("db error getting domain block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// There are just no entries for this ID.
err = fmt.Errorf("no domain block entry exists with ID %s", id)
return nil, gtserror.NewErrorNotFound(err, err.Error())
}
// Prepare the domain block to return, *before* the deletion goes through.
apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock)
if errWithCode != nil {
return nil, errWithCode
}
// Copy value of the domain block.
domainBlockC := new(gtsmodel.DomainBlock)
*domainBlockC = *domainBlock
// Delete the original domain block.
if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
err = gtserror.Newf("db error deleting domain block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// Process the side effects of the domain unblock
// asynchronously since it might take a while.
p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) {
p.domainUnblockSideEffects(ctx, domainBlockC) // Use the copy.
})
return apiDomainBlock, nil return apiDomainBlock, nil
} }
// initiateDomainBlockSideEffects should be called asynchronously, to process the side effects of a domain block: // stubbifyInstance renders the given instance as a stub,
// removing most information from it and marking it as
// suspended.
// //
// 1. Strip most info away from the instance entry for the domain. // For caller's convenience, this function returns the db
// 2. Delete the instance account for that instance if it exists. // names of all columns that are updated by it.
// 3. Select all accounts from this instance and pass them through the delete functionality of the processor. func stubbifyInstance(instance *gtsmodel.Instance, domainBlockID string) []string {
func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) { instance.Title = ""
l := log.WithContext(ctx).WithFields(kv.Fields{{"domain", block.Domain}}...) instance.SuspendedAt = time.Now()
l.Debug("processing domain block side effects") instance.DomainBlockID = domainBlockID
instance.ShortDescription = ""
instance.Description = ""
instance.Terms = ""
instance.ContactEmail = ""
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
// if we have an instance entry for this domain, update it with the new block ID and clear all fields return []string{
instance := &gtsmodel.Instance{}
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
updatingColumns := []string{
"title", "title",
"updated_at",
"suspended_at", "suspended_at",
"domain_block_id", "domain_block_id",
"short_description", "short_description",
@ -111,201 +298,183 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
"contact_account_id", "contact_account_id",
"version", "version",
} }
instance.Title = ""
instance.UpdatedAt = time.Now()
instance.SuspendedAt = time.Now()
instance.DomainBlockID = block.ID
instance.ShortDescription = ""
instance.Description = ""
instance.Terms = ""
instance.ContactEmail = ""
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
}
l.Debug("domainBlockProcessSideEffects: instance entry updated")
} }
// if we have an instance account for this instance, delete it // apiDomainBlock is a cheeky shortcut function for returning the API
if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { // version of the given domainBlock, or an appropriate error if
if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil { // something goes wrong.
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) func (p *Processor) apiDomainBlock(ctx context.Context, domainBlock *gtsmodel.DomainBlock) (*apimodel.DomainBlock, gtserror.WithCode) {
} apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, false)
}
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
limit := 20 // just select 20 accounts at a time so we don't nuke our DB/mem with one huge query
var maxID string // this is initially an empty string so we'll start at the top of accounts list (sorted by ID)
selectAccountsLoop:
for {
accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
if err != nil { if err != nil {
if err == db.ErrNoEntries { err = gtserror.Newf("error converting domain block for %s to api model : %w", domainBlock.Domain, err)
// no accounts left for this instance so we're done return nil, gtserror.NewErrorInternalError(err)
l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain)
break selectAccountsLoop
}
// an actual error has occurred
l.Errorf("domainBlockProcessSideEffects: db error selecting accounts for domain %s: %s", block.Domain, err)
break selectAccountsLoop
} }
for i, a := range accounts { return apiDomainBlock, nil
l.Debugf("putting delete for account %s in the clientAPI channel", a.Username) }
// pass the account delete through the client api channel for processing // rangeAccounts iterates through all accounts originating from the
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ // given domain, and calls the provided range function on each account.
// If an error is returned from the range function, the loop will stop
// and return the error.
func (p *Processor) rangeAccounts(
ctx context.Context,
domain string,
rangeF func(*gtsmodel.Account) error,
) error {
var (
limit = 50 // Limit selection to avoid spiking mem/cpu.
maxID string // Start with empty string to select from top.
)
for {
// Get (next) page of accounts.
accounts, err := p.state.DB.GetInstanceAccounts(ctx, domain, maxID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real db error.
return gtserror.Newf("db error getting instance accounts: %w", err)
}
if len(accounts) == 0 {
// No accounts left, we're done.
return nil
}
// Set next max ID for paging down.
maxID = accounts[len(accounts)-1].ID
// Call provided range function.
for _, account := range accounts {
if err := rangeF(account); err != nil {
return err
}
}
}
}
// domainBlockSideEffects processes the side effects of a domain block:
//
// 1. Strip most info away from the instance entry for the domain.
// 2. Pass each account from the domain to the processor for deletion.
//
// It should be called asynchronously, since it can take a while when
// there are many accounts present on the given domain.
func (p *Processor) domainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"domain", block.Domain},
}...)
l.Debug("processing domain block side effects")
// If we have an instance entry for this domain,
// update it with the new block ID and clear all fields
instance, err := p.state.DB.GetInstance(ctx, block.Domain)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
l.Errorf("db error getting instance %s: %q", block.Domain, err)
}
if instance != nil {
// We had an entry for this domain.
columns := stubbifyInstance(instance, block.ID)
if err := p.state.DB.UpdateInstance(ctx, instance, columns...); err != nil {
l.Errorf("db error updating instance: %s", err)
} else {
l.Debug("instance entry updated")
}
}
// For each account that belongs to this domain, create
// an account delete message to process via the client API
// worker pool, to remove that account's posts, media, etc.
msgs := []messages.FromClientAPI{}
if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error {
msgs = append(msgs, messages.FromClientAPI{
APObjectType: ap.ActorPerson, APObjectType: ap.ActorPerson,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: block, GTSModel: block,
OriginAccount: account, OriginAccount: account,
TargetAccount: a, TargetAccount: account,
}) })
// if this is the last account in the slice, set the maxID appropriately for the next query return nil
if i == len(accounts)-1 { }); err != nil {
maxID = a.ID l.Errorf("error while ranging through accounts: %q", err)
} }
// Batch process all accreted messages.
p.state.Workers.EnqueueClientAPI(ctx, msgs...)
} }
// domainUnblockSideEffects processes the side effects of undoing a
// domain block:
//
// 1. Mark instance entry as no longer suspended.
// 2. Mark each account from the domain as no longer suspended, if the
// suspension origin corresponds to the ID of the provided domain block.
//
// It should be called asynchronously, since it can take a while when
// there are many accounts present on the given domain.
func (p *Processor) domainUnblockSideEffects(ctx context.Context, block *gtsmodel.DomainBlock) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"domain", block.Domain},
}...)
l.Debug("processing domain unblock side effects")
// Update instance entry for this domain, if we have it.
instance, err := p.state.DB.GetInstance(ctx, block.Domain)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
l.Errorf("db error getting instance %s: %q", block.Domain, err)
}
if instance != nil {
// We had an entry, update it to signal
// that it's no longer suspended.
instance.SuspendedAt = time.Time{}
instance.DomainBlockID = ""
if err := p.state.DB.UpdateInstance(
ctx,
instance,
"suspended_at",
"domain_block_id",
); err != nil {
l.Errorf("db error updating instance: %s", err)
} else {
l.Debug("instance entry updated")
} }
} }
// DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file. // Unsuspend all accounts whose suspension origin was this domain block.
func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error {
f, err := domains.Open() if account.SuspensionOrigin == "" || account.SuspendedAt.IsZero() {
if err != nil { // Account wasn't suspended, nothing to do.
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error opening attachment: %s", err)) return nil
}
buf := new(bytes.Buffer)
size, err := io.Copy(buf, f)
if err != nil {
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error reading attachment: %s", err))
}
if size == 0 {
return nil, gtserror.NewErrorBadRequest(errors.New("DomainBlocksImport: could not read provided attachment: size 0 bytes"))
} }
d := []apimodel.DomainBlock{} if account.SuspensionOrigin != block.ID {
if err := json.Unmarshal(buf.Bytes(), &d); err != nil { // Account was suspended, but not by
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: could not read provided attachment: %s", err)) // this domain block, leave it alone.
return nil
} }
blocks := []*apimodel.DomainBlock{} // Account was suspended by this domain
for _, d := range d { // block, mark it as unsuspended.
block, err := p.DomainBlockCreate(ctx, account, d.Domain.Domain, false, d.PublicComment, "", "") account.SuspendedAt = time.Time{}
if err != nil { account.SuspensionOrigin = ""
return nil, err
if err := p.state.DB.UpdateAccount(
ctx,
account,
"suspended_at",
"suspension_origin",
); err != nil {
return gtserror.Newf("db error updating account %s: %w", account.Username, err)
} }
blocks = append(blocks, block) return nil
} }); err != nil {
l.Errorf("error while ranging through accounts: %q", err)
return blocks, nil
}
// DomainBlocksGet returns all existing domain blocks.
// If export is true, the format will be suitable for writing out to an export.
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks := []*gtsmodel.DomainBlock{}
if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
} }
} }
apiDomainBlocks := []*apimodel.DomainBlock{}
for _, b := range domainBlocks {
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, b, export)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
apiDomainBlocks = append(apiDomainBlocks, apiDomainBlock)
}
return apiDomainBlocks, nil
}
// DomainBlockGet returns one domain block with the given id.
// If export is true, the format will be suitable for writing out to an export.
func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{}
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}
// there are no entries for this ID
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id))
}
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, export)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
return apiDomainBlock, nil
}
// DomainBlockDelete removes one domain block with the given ID.
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{}
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}
// there are no entries for this ID
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id))
}
// prepare the domain block to return
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, false)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// Delete the domain block
if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// remove the domain block reference from the instance, if we have an entry for it
i := &gtsmodel.Instance{}
if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain},
{Key: "domain_block_id", Value: id},
}, i); err == nil {
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
i.SuspendedAt = time.Time{}
i.DomainBlockID = ""
i.UpdatedAt = time.Now()
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
}
}
// unsuspend all accounts whose suspension origin was this domain block
// 1. remove the 'suspended_at' entry from their accounts
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
}
// 2. remove the 'suspension_origin' entry from their accounts
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
}
return apiDomainBlock, nil
}

View file

@ -34,11 +34,12 @@ import (
) )
func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) { func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {
i := &gtsmodel.Instance{} instance, err := p.state.DB.GetInstance(ctx, config.GetHost())
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { if err != nil {
return nil, err return nil, err
} }
return i, nil
return instance, nil
} }
func (p *Processor) InstanceGetV1(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode) { func (p *Processor) InstanceGetV1(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode) {
@ -137,9 +138,10 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.InstanceV1, gtserror.WithCode) { func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.InstanceV1, gtserror.WithCode) {
// fetch the instance entry from the db for processing // fetch the instance entry from the db for processing
i := &gtsmodel.Instance{}
host := config.GetHost() host := config.GetHost()
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
instance, err := p.state.DB.GetInstance(ctx, host)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))
} }
@ -157,7 +159,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("site title invalid: %s", err)) return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("site title invalid: %s", err))
} }
updatingColumns = append(updatingColumns, "title") updatingColumns = append(updatingColumns, "title")
i.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title instance.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title
} }
// validate & update site contact account if it's set on the form // validate & update site contact account if it's set on the form
@ -192,7 +194,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
updatingColumns = append(updatingColumns, "contact_account_id") updatingColumns = append(updatingColumns, "contact_account_id")
i.ContactAccountID = contactAccount.ID instance.ContactAccountID = contactAccount.ID
} }
// validate & update site contact email if it's set on the form // validate & update site contact email if it's set on the form
@ -204,7 +206,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
} }
} }
updatingColumns = append(updatingColumns, "contact_email") updatingColumns = append(updatingColumns, "contact_email")
i.ContactEmail = contactEmail instance.ContactEmail = contactEmail
} }
// validate & update site short description if it's set on the form // validate & update site short description if it's set on the form
@ -213,7 +215,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
updatingColumns = append(updatingColumns, "short_description") updatingColumns = append(updatingColumns, "short_description")
i.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it instance.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it
} }
// validate & update site description if it's set on the form // validate & update site description if it's set on the form
@ -222,7 +224,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
updatingColumns = append(updatingColumns, "description") updatingColumns = append(updatingColumns, "description")
i.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it instance.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it
} }
// validate & update site terms if it's set on the form // validate & update site terms if it's set on the form
@ -231,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
updatingColumns = append(updatingColumns, "terms") updatingColumns = append(updatingColumns, "terms")
i.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it instance.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it
} }
var updateInstanceAccount bool var updateInstanceAccount bool
@ -273,12 +275,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
} }
if len(updatingColumns) != 0 { if len(updatingColumns) != 0 {
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateInstance(ctx, instance, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
} }
} }
ai, err := p.tc.InstanceToAPIV1Instance(ctx, i) ai, err := p.tc.InstanceToAPIV1Instance(ctx, instance)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err))
} }

View file

@ -40,6 +40,9 @@ EXPECT=$(cat <<"EOF"
"follow-request-ttl": 1800000000000, "follow-request-ttl": 1800000000000,
"follow-sweep-freq": 60000000000, "follow-sweep-freq": 60000000000,
"follow-ttl": 1800000000000, "follow-ttl": 1800000000000,
"instance-max-size": 2000,
"instance-sweep-freq": 60000000000,
"instance-ttl": 1800000000000,
"list-entry-max-size": 2000, "list-entry-max-size": 2000,
"list-entry-sweep-freq": 60000000000, "list-entry-sweep-freq": 60000000000,
"list-entry-ttl": 1800000000000, "list-entry-ttl": 1800000000000,