[feature] Filters v1 (#2594)

* Implement client-side v1 filters

* Exclude linter false positives

* Update test/envparsing.sh

* Fix minor Swagger, style, and Bun usage issues

* Regenerate Swagger

* De-generify filter keywords

* Remove updating filter statuses

This is an operation that the Mastodon v2 filter API doesn't actually have, because filter statuses, unlike keywords, don't have options: the only info they contain is the status ID to be filtered.

* Add a test for filter statuses specifically

* De-generify filter statuses

* Inline FilterEntry

* Use vertical style for Bun operations consistently

* Add comment on Filter DB interface

* Remove GoLand linter control comments

Our existing linters should catch these, or they don't matter very much

* Reduce memory ratio for filters
This commit is contained in:
Vyr Cossont 2024-03-06 02:15:58 -08:00 committed by GitHub
parent 7bc536d1f7
commit 61a2b91f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 4672 additions and 52 deletions

View file

@ -83,3 +83,12 @@ linters-settings:
# Enable all checks, but disable SA1012: nil context passing. # Enable all checks, but disable SA1012: nil context passing.
# See: https://staticcheck.io/docs/configuration/options/#checks # See: https://staticcheck.io/docs/configuration/options/#checks
checks: ["all", "-SA1012"] checks: ["all", "-SA1012"]
issues:
exclude-rules:
# Exclude VSCode custom folding region comments in files that use them.
# Already fixed in go-critic and can be removed next time go-critic is updated.
- linters:
- gocritic
path: internal/db/filter.go
text: 'commentFormatting: put a space between `//` and comment text'

View file

@ -1209,6 +1209,61 @@ definitions:
type: object type: object
x-go-name: Field x-go-name: Field
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
filterContext:
description: v1 and v2 filter APIs use the same set of contexts.
title: FilterContext represents the context in which to apply a filter.
type: string
x-go-name: FilterContext
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
filterV1:
description: |-
Note that v1 filters are mapped to v2 filters and v2 filter keywords internally.
If whole_word is true, client app should do:
Define word constituent character for your app. In the official implementation, its [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby.
Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation).
If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match.
If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match.
Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details.
properties:
context:
description: The contexts in which the filter should be applied.
example:
- home
- public
items:
$ref: '#/definitions/filterContext'
minLength: 1
type: array
uniqueItems: true
x-go-name: Context
expires_at:
description: When the filter should no longer be applied. Null if the filter does not expire.
example: "2024-02-01T02:57:49Z"
type: string
x-go-name: ExpiresAt
id:
description: The ID of the filter in the database.
type: string
x-go-name: ID
irreversible:
description: Should matching entities be removed from the user's timelines/views, instead of hidden?
example: false
type: boolean
x-go-name: Irreversible
phrase:
description: The text to be filtered.
example: fnord
type: string
x-go-name: Phrase
whole_word:
description: Should the filter consider word boundaries?
example: true
type: boolean
x-go-name: WholeWord
title: FilterV1 represents a user-defined filter for determining which statuses should not be shown to the user.
type: object
x-go-name: FilterV1
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
headerFilterCreateRequest: headerFilterCreateRequest:
properties: properties:
header: header:
@ -5570,6 +5625,246 @@ paths:
summary: Get an array of all hashtags that you currently have featured on your profile. summary: Get an array of all hashtags that you currently have featured on your profile.
tags: tags:
- featured_tags - featured_tags
/api/v1/filters:
get:
operationId: filtersV1Get
produces:
- application/json
responses:
"200":
description: Requested filters.
schema:
$ref: '#/definitions/filterV1'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:filters
summary: Get all filters for the authenticated account.
tags:
- filters
post:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: filterV1Post
parameters:
- description: The text to be filtered.
example: fnord
in: formData
maxLength: 40
name: phrase
required: true
type: string
- description: The contexts in which the filter should be applied.
enum:
- home
- notifications
- public
- thread
- account
example:
- home
- public
in: formData
items:
$ref: '#/definitions/filterContext'
minLength: 1
name: context
required: true
type: array
uniqueItems: true
- description: Number of seconds from now that the filter should expire. If omitted, filter never expires.
example: 86400
in: formData
name: expires_in
type: number
- default: false
description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet.
example: false
in: formData
name: irreversible
type: boolean
- default: false
description: Should the filter consider word boundaries?
example: true
in: formData
name: whole_word
type: boolean
produces:
- application/json
responses:
"200":
description: New filter.
schema:
$ref: '#/definitions/filterV1'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"422":
description: unprocessable content
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:filters
summary: Create a single filter.
tags:
- filters
/api/v1/filters/{id}:
delete:
operationId: filterV1Delete
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: filter deleted
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:filters
summary: Delete a single filter with the given ID.
tags:
- filters
get:
operationId: filterV1Get
parameters:
- description: ID of the filter
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Requested filter.
schema:
$ref: '#/definitions/filterV1'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:filters
summary: Get a single filter with the given ID.
tags:
- filters
put:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: filterV1Put
parameters:
- description: ID of the filter.
in: path
name: id
required: true
type: string
- description: The text to be filtered.
example: fnord
in: formData
maxLength: 40
name: phrase
required: true
type: string
- description: The contexts in which the filter should be applied.
enum:
- home
- notifications
- public
- thread
- account
example:
- home
- public
in: formData
items:
$ref: '#/definitions/filterContext'
minLength: 1
name: context
required: true
type: array
uniqueItems: true
- description: Number of seconds from now that the filter should expire. If omitted, filter never expires.
example: 86400
in: formData
name: expires_in
type: number
- default: false
description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet.
example: false
in: formData
name: irreversible
type: boolean
- default: false
description: Should the filter consider word boundaries?
example: true
in: formData
name: whole_word
type: boolean
produces:
- application/json
responses:
"200":
description: Updated filter.
schema:
$ref: '#/definitions/filterV1'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"422":
description: unprocessable content
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:filters
summary: Update a single filter with the given ID.
tags:
- filters
/api/v1/follow_requests: /api/v1/follow_requests:
get: get:
description: |- description: |-
@ -7971,6 +8266,7 @@ securityDefinitions:
read:blocks: grant read access to blocks read:blocks: grant read access to blocks
read:custom_emojis: grant read access to custom_emojis read:custom_emojis: grant read access to custom_emojis
read:favourites: grant read access to favourites read:favourites: grant read access to favourites
read:filters: grant read access to filters
read:follows: grant read access to follows read:follows: grant read access to follows
read:lists: grant read access to lists read:lists: grant read access to lists
read:media: grant read access to media read:media: grant read access to media
@ -7983,6 +8279,7 @@ securityDefinitions:
write: grants write access to everything write: grants write access to everything
write:accounts: grants write access to accounts write:accounts: grants write access to accounts
write:blocks: grants write access to blocks write:blocks: grants write access to blocks
write:filters: grants write access to filters
write:follows: grants write access to follows write:follows: grants write access to follows
write:lists: grants write access to lists write:lists: grants write access to lists
write:media: grants write access to media write:media: grants write access to media

View file

@ -36,6 +36,7 @@
// read:blocks: grant read access to blocks // read:blocks: grant read access to blocks
// read:custom_emojis: grant read access to custom_emojis // read:custom_emojis: grant read access to custom_emojis
// read:favourites: grant read access to favourites // read:favourites: grant read access to favourites
// read:filters: grant read access to filters
// read:follows: grant read access to follows // read:follows: grant read access to follows
// read:lists: grant read access to lists // read:lists: grant read access to lists
// read:media: grant read access to media // read:media: grant read access to media
@ -48,6 +49,7 @@
// write: grants write access to everything // write: grants write access to everything
// write:accounts: grants write access to accounts // write:accounts: grants write access to accounts
// write:blocks: grants write access to blocks // write:blocks: grants write access to blocks
// write:filters: grants write access to filters
// write:follows: grants write access to follows // write:follows: grants write access to follows
// write:lists: grants write access to lists // write:lists: grants write access to lists
// write:media: grants write access to media // write:media: grants write access to media

View file

@ -29,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/client/customemojis" "github.com/superseriousbusiness/gotosocial/internal/api/client/customemojis"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
"github.com/superseriousbusiness/gotosocial/internal/api/client/featuredtags" "github.com/superseriousbusiness/gotosocial/internal/api/client/featuredtags"
filter "github.com/superseriousbusiness/gotosocial/internal/api/client/filters" filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/api/client/lists" "github.com/superseriousbusiness/gotosocial/internal/api/client/lists"
@ -62,7 +62,7 @@ type Client struct {
customEmojis *customemojis.Module // api/v1/custom_emojis customEmojis *customemojis.Module // api/v1/custom_emojis
favourites *favourites.Module // api/v1/favourites favourites *favourites.Module // api/v1/favourites
featuredTags *featuredtags.Module // api/v1/featured_tags featuredTags *featuredtags.Module // api/v1/featured_tags
filters *filter.Module // api/v1/filters filtersV1 *filtersV1.Module // api/v1/filters
followRequests *followrequests.Module // api/v1/follow_requests followRequests *followrequests.Module // api/v1/follow_requests
instance *instance.Module // api/v1/instance instance *instance.Module // api/v1/instance
lists *lists.Module // api/v1/lists lists *lists.Module // api/v1/lists
@ -104,7 +104,7 @@ func (c *Client) Route(r *router.Router, m ...gin.HandlerFunc) {
c.customEmojis.Route(h) c.customEmojis.Route(h)
c.favourites.Route(h) c.favourites.Route(h)
c.featuredTags.Route(h) c.featuredTags.Route(h)
c.filters.Route(h) c.filtersV1.Route(h)
c.followRequests.Route(h) c.followRequests.Route(h)
c.instance.Route(h) c.instance.Route(h)
c.lists.Route(h) c.lists.Route(h)
@ -134,7 +134,7 @@ func NewClient(db db.DB, p *processing.Processor) *Client {
customEmojis: customemojis.New(p), customEmojis: customemojis.New(p),
favourites: favourites.New(p), favourites: favourites.New(p),
featuredTags: featuredtags.New(p), featuredTags: featuredtags.New(p),
filters: filter.New(p), filtersV1: filtersV1.New(p),
followRequests: followrequests.New(p), followRequests: followrequests.New(p),
instance: instance.New(p), instance: instance.New(p),
lists: lists.New(p), lists: lists.New(p),

View file

@ -15,20 +15,23 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
package filter package v1
import ( import (
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"net/http"
) )
const ( const (
// BasePath is the base path for serving the filters API, minus the 'api' prefix // BasePath is the base path for serving the filters API, minus the 'api' prefix
BasePath = "/v1/filters" BasePath = "/v1/filters"
// BasePathWithID is the base path with the ID key in it, for operations on an existing filter.
BasePathWithID = BasePath + "/:" + apiutil.IDKey
) )
// Module implements APIs for client-side aka "v1" filtering.
type Module struct { type Module struct {
processor *processing.Processor processor *processing.Processor
} }
@ -41,4 +44,8 @@ func New(processor *processing.Processor) *Module {
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, BasePath, m.FiltersGETHandler) attachHandler(http.MethodGet, BasePath, m.FiltersGETHandler)
attachHandler(http.MethodPost, BasePath, m.FilterPOSTHandler)
attachHandler(http.MethodGet, BasePathWithID, m.FilterGETHandler)
attachHandler(http.MethodPut, BasePathWithID, m.FilterPUTHandler)
attachHandler(http.MethodDelete, BasePathWithID, m.FilterDELETEHandler)
} }

View file

@ -0,0 +1,117 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"github.com/stretchr/testify/suite"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/filter/visibility"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
"testing"
)
type FiltersTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
mediaManager *media.Manager
federator *federation.Federator
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status
testFilters map[string]*gtsmodel.Filter
testFilterKeywords map[string]*gtsmodel.FilterKeyword
testFilterStatuses map[string]*gtsmodel.FilterStatus
// module being tested
filtersModule *filtersV1.Module
}
func (suite *FiltersTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses()
suite.testFilters = testrig.NewTestFilters()
suite.testFilterKeywords = testrig.NewTestFilterKeywords()
suite.testFilterStatuses = testrig.NewTestFilterStatuses()
}
func (suite *FiltersTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartNoopWorkers(&suite.state)
testrig.InitTestConfig()
config.Config(func(cfg *config.Configuration) {
cfg.WebAssetBaseDir = "../../../../../web/assets/"
cfg.WebTemplateBaseDir = "../../../../../web/templates/"
})
testrig.InitTestLog()
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
typeutils.NewConverter(&suite.state),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.filtersModule = filtersV1.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../../testrig/media")
}
func (suite *FiltersTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
}
func TestFiltersTestSuite(t *testing.T) {
suite.Run(t, new(FiltersTestSuite))
}

View file

@ -0,0 +1,90 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// FilterDELETEHandler swagger:operation DELETE /api/v1/filters/{id} filterV1Delete
//
// Delete a single filter with the given ID.
//
// ---
// tags:
// - filters
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - write:filters
//
// responses:
// '200':
// description: filter deleted
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) FilterDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
errWithCode = m.processor.FiltersV1().Delete(c.Request.Context(), authed.Account, id)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiutil.EmptyJSONObject)
}

View file

@ -0,0 +1,112 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func (suite *FiltersTestSuite) deleteFilter(
filterKeywordID string,
expectedHTTPStatus int,
expectedBody string,
) error {
// instantiate recorder + test context
recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// create the request
ctx.Request = httptest.NewRequest(http.MethodDelete, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil)
ctx.Request.Header.Set("accept", "application/json")
ctx.AddParam("id", filterKeywordID)
// trigger the handler
suite.filtersModule.FilterDELETEHandler(ctx)
// read the response
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
return err
}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return errs.Combine()
}
resp := &struct{}{}
if err := json.Unmarshal(b, resp); err != nil {
return err
}
return nil
}
func (suite *FiltersTestSuite) TestDeleteFilter() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
err := suite.deleteFilter(id, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestDeleteAnotherAccountsFilter() {
id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID
err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestDeleteNonexistentFilter() {
id := "not_even_a_real_ULID"
err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -0,0 +1,93 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// FilterGETHandler swagger:operation GET /api/v1/filters/{id} filterV1Get
//
// Get a single filter with the given ID.
//
// ---
// tags:
// - filters
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the filter
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:filters
//
// responses:
// '200':
// name: filter
// description: Requested filter.
// schema:
// "$ref": "#/definitions/filterV1"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) FilterGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiFilter, errWithCode := m.processor.FiltersV1().Get(c.Request.Context(), authed.Account, id)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiFilter)
}

View file

@ -0,0 +1,121 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func (suite *FiltersTestSuite) getFilter(
filterKeywordID string,
expectedHTTPStatus int,
expectedBody string,
) (*apimodel.FilterV1, error) {
// instantiate recorder + test context
recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// create the request
ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil)
ctx.Request.Header.Set("accept", "application/json")
ctx.AddParam("id", filterKeywordID)
// trigger the handler
suite.filtersModule.FilterGETHandler(ctx)
// read the response
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
return nil, err
}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}
resp := &apimodel.FilterV1{}
if err := json.Unmarshal(b, resp); err != nil {
return nil, err
}
return resp, nil
}
func (suite *FiltersTestSuite) TestGetFilter() {
// v1 filters map to individual filter keywords, but also use the settings of the associated filter.
expectedFilterGtsModel := suite.testFilters["local_account_1_filter_1"]
expectedFilterKeywordGtsModel := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"]
filter, err := suite.getFilter(expectedFilterKeywordGtsModel.ID, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.NotEmpty(filter)
suite.Equal(expectedFilterGtsModel.Action == gtsmodel.FilterActionHide, filter.Irreversible)
suite.Equal(expectedFilterKeywordGtsModel.ID, filter.ID)
suite.Equal(expectedFilterKeywordGtsModel.Keyword, filter.Phrase)
}
func (suite *FiltersTestSuite) TestGetAnotherAccountsFilter() {
id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID
_, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestGetNonexistentFilter() {
id := "not_even_a_real_ULID"
_, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -0,0 +1,147 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// FilterPOSTHandler swagger:operation POST /api/v1/filters filterV1Post
//
// Create a single filter.
//
// ---
// tags:
// - filters
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: phrase
// in: formData
// required: true
// description: The text to be filtered.
// maxLength: 40
// type: string
// example: "fnord"
// -
// name: context
// in: formData
// required: true
// description: The contexts in which the filter should be applied.
// enum:
// - home
// - notifications
// - public
// - thread
// - account
// example:
// - home
// - public
// items:
// $ref: '#/definitions/filterContext'
// minLength: 1
// type: array
// uniqueItems: true
// -
// name: expires_in
// in: formData
// description: Number of seconds from now that the filter should expire. If omitted, filter never expires.
// type: number
// example: 86400
// -
// name: irreversible
// in: formData
// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet.
// type: boolean
// default: false
// example: false
// -
// name: whole_word
// in: formData
// description: Should the filter consider word boundaries?
// type: boolean
// default: false
// example: true
//
// security:
// - OAuth2 Bearer:
// - write:filters
//
// responses:
// '200':
// name: filter
// description: New filter.
// schema:
// "$ref": "#/definitions/filterV1"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '422':
// description: unprocessable content
// '500':
// description: internal server error
func (m *Module) FilterPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.FilterCreateUpdateRequestV1{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if err := validateNormalizeCreateUpdateFilter(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiFilter, errWithCode := m.processor.FiltersV1().Create(c.Request.Context(), authed.Account, form)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, apiFilter)
}

View file

@ -0,0 +1,239 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func (suite *FiltersTestSuite) postFilter(
phrase *string,
context *[]string,
irreversible *bool,
wholeWord *bool,
expiresIn *int,
requestJson *string,
expectedHTTPStatus int,
expectedBody string,
) (*apimodel.FilterV1, error) {
// instantiate recorder + test context
recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// create the request
ctx.Request = httptest.NewRequest(http.MethodPost, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil)
ctx.Request.Header.Set("accept", "application/json")
if requestJson != nil {
ctx.Request.Header.Set("content-type", "application/json")
ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson))
} else {
ctx.Request.Form = make(url.Values)
if phrase != nil {
ctx.Request.Form["phrase"] = []string{*phrase}
}
if context != nil {
ctx.Request.Form["context[]"] = *context
}
if irreversible != nil {
ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)}
}
if wholeWord != nil {
ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)}
}
if expiresIn != nil {
ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)}
}
}
// trigger the handler
suite.filtersModule.FilterPOSTHandler(ctx)
// read the response
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
return nil, err
}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}
resp := &apimodel.FilterV1{}
if err := json.Unmarshal(b, resp); err != nil {
return nil, err
}
return resp, nil
}
func (suite *FiltersTestSuite) TestPostFilterFull() {
phrase := "GNU/Linux"
context := []string{"home", "public"}
irreversible := false
wholeWord := true
expiresIn := 86400
filter, err := suite.postFilter(&phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal(phrase, filter.Phrase)
filterContext := make([]string, 0, len(filter.Context))
for _, c := range filter.Context {
filterContext = append(filterContext, string(c))
}
suite.ElementsMatch(context, filterContext)
suite.Equal(irreversible, filter.Irreversible)
suite.Equal(wholeWord, filter.WholeWord)
if suite.NotNil(filter.ExpiresAt) {
suite.NotEmpty(*filter.ExpiresAt)
}
}
func (suite *FiltersTestSuite) TestPostFilterFullJSON() {
// Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in".
requestJson := `{
"phrase":"GNU/Linux",
"context": ["home", "public"],
"irreversible": false,
"whole_word": true,
"expires_in": 86400.1
}`
filter, err := suite.postFilter(nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal("GNU/Linux", filter.Phrase)
suite.ElementsMatch(
[]apimodel.FilterContext{
apimodel.FilterContextHome,
apimodel.FilterContextPublic,
},
filter.Context,
)
suite.Equal(false, filter.Irreversible)
suite.Equal(true, filter.WholeWord)
if suite.NotNil(filter.ExpiresAt) {
suite.NotEmpty(*filter.ExpiresAt)
}
}
func (suite *FiltersTestSuite) TestPostFilterMinimal() {
phrase := "GNU/Linux"
context := []string{"home"}
filter, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal(phrase, filter.Phrase)
filterContext := make([]string, 0, len(filter.Context))
for _, c := range filter.Context {
filterContext = append(filterContext, string(c))
}
suite.ElementsMatch(context, filterContext)
suite.False(filter.Irreversible)
suite.False(filter.WholeWord)
suite.Nil(filter.ExpiresAt)
}
func (suite *FiltersTestSuite) TestPostFilterEmptyPhrase() {
phrase := ""
context := []string{"home"}
_, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPostFilterMissingPhrase() {
context := []string{"home"}
_, err := suite.postFilter(nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPostFilterEmptyContext() {
phrase := "GNU/Linux"
context := []string{}
_, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPostFilterMissingContext() {
phrase := "GNU/Linux"
_, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
// There should be a filter with this phrase as its title in our test fixtures. Creating another should fail.
func (suite *FiltersTestSuite) TestPostFilterTitleConflict() {
phrase := "fnord"
_, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
// FUTURE: this should be removed once we support server-side filters.
func (suite *FiltersTestSuite) TestPostFilterIrreversibleNotSupported() {
phrase := "GNU/Linux"
context := []string{"home"}
irreversible := true
_, err := suite.postFilter(&phrase, &context, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -0,0 +1,159 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// FilterPUTHandler swagger:operation PUT /api/v1/filters/{id} filterV1Put
//
// Update a single filter with the given ID.
//
// ---
// tags:
// - filters
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// in: path
// type: string
// required: true
// description: ID of the filter.
// -
// name: phrase
// in: formData
// required: true
// description: The text to be filtered.
// maxLength: 40
// type: string
// example: "fnord"
// -
// name: context
// in: formData
// required: true
// description: The contexts in which the filter should be applied.
// enum:
// - home
// - notifications
// - public
// - thread
// - account
// example:
// - home
// - public
// items:
// $ref: '#/definitions/filterContext'
// minLength: 1
// type: array
// uniqueItems: true
// -
// name: expires_in
// in: formData
// description: Number of seconds from now that the filter should expire. If omitted, filter never expires.
// type: number
// example: 86400
// -
// name: irreversible
// in: formData
// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet.
// type: boolean
// default: false
// example: false
// -
// name: whole_word
// in: formData
// description: Should the filter consider word boundaries?
// type: boolean
// default: false
// example: true
//
// security:
// - OAuth2 Bearer:
// - write:filters
//
// responses:
// '200':
// name: filter
// description: Updated filter.
// schema:
// "$ref": "#/definitions/filterV1"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '422':
// description: unprocessable content
// '500':
// description: internal server error
func (m *Module) FilterPUTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
form := &apimodel.FilterCreateUpdateRequestV1{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if err := validateNormalizeCreateUpdateFilter(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiFilter, errWithCode := m.processor.FiltersV1().Update(c.Request.Context(), authed.Account, id, form)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, apiFilter)
}

View file

@ -0,0 +1,269 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func (suite *FiltersTestSuite) putFilter(
filterKeywordID string,
phrase *string,
context *[]string,
irreversible *bool,
wholeWord *bool,
expiresIn *int,
requestJson *string,
expectedHTTPStatus int,
expectedBody string,
) (*apimodel.FilterV1, error) {
// instantiate recorder + test context
recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// create the request
ctx.Request = httptest.NewRequest(http.MethodPut, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil)
ctx.Request.Header.Set("accept", "application/json")
if requestJson != nil {
ctx.Request.Header.Set("content-type", "application/json")
ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson))
} else {
ctx.Request.Form = make(url.Values)
if phrase != nil {
ctx.Request.Form["phrase"] = []string{*phrase}
}
if context != nil {
ctx.Request.Form["context[]"] = *context
}
if irreversible != nil {
ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)}
}
if wholeWord != nil {
ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)}
}
if expiresIn != nil {
ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)}
}
}
ctx.AddParam("id", filterKeywordID)
// trigger the handler
suite.filtersModule.FilterPUTHandler(ctx)
// read the response
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
return nil, err
}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}
resp := &apimodel.FilterV1{}
if err := json.Unmarshal(b, resp); err != nil {
return nil, err
}
return resp, nil
}
func (suite *FiltersTestSuite) TestPutFilterFull() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := "GNU/Linux"
context := []string{"home", "public"}
irreversible := false
wholeWord := true
expiresIn := 86400
filter, err := suite.putFilter(id, &phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal(phrase, filter.Phrase)
filterContext := make([]string, 0, len(filter.Context))
for _, c := range filter.Context {
filterContext = append(filterContext, string(c))
}
suite.ElementsMatch(context, filterContext)
suite.Equal(irreversible, filter.Irreversible)
suite.Equal(wholeWord, filter.WholeWord)
if suite.NotNil(filter.ExpiresAt) {
suite.NotEmpty(*filter.ExpiresAt)
}
}
func (suite *FiltersTestSuite) TestPutFilterFullJSON() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
// Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in".
requestJson := `{
"phrase":"GNU/Linux",
"context": ["home", "public"],
"irreversible": false,
"whole_word": true,
"expires_in": 86400.1
}`
filter, err := suite.putFilter(id, nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal("GNU/Linux", filter.Phrase)
suite.ElementsMatch(
[]apimodel.FilterContext{
apimodel.FilterContextHome,
apimodel.FilterContextPublic,
},
filter.Context,
)
suite.Equal(false, filter.Irreversible)
suite.Equal(true, filter.WholeWord)
if suite.NotNil(filter.ExpiresAt) {
suite.NotEmpty(*filter.ExpiresAt)
}
}
func (suite *FiltersTestSuite) TestPutFilterMinimal() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := "GNU/Linux"
context := []string{"home"}
filter, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.Equal(phrase, filter.Phrase)
filterContext := make([]string, 0, len(filter.Context))
for _, c := range filter.Context {
filterContext = append(filterContext, string(c))
}
suite.ElementsMatch(context, filterContext)
suite.False(filter.Irreversible)
suite.False(filter.WholeWord)
suite.Nil(filter.ExpiresAt)
}
func (suite *FiltersTestSuite) TestPutFilterEmptyPhrase() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := ""
context := []string{"home"}
_, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPutFilterMissingPhrase() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
context := []string{"home"}
_, err := suite.putFilter(id, nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPutFilterEmptyContext() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := "GNU/Linux"
context := []string{}
_, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPutFilterMissingContext() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := "GNU/Linux"
_, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
// There should be a filter with this phrase as its title in our test fixtures. Changing ours to that title should fail.
func (suite *FiltersTestSuite) TestPutFilterTitleConflict() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
phrase := "metasyntactic variables"
_, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
// FUTURE: this should be removed once we support server-side filters.
func (suite *FiltersTestSuite) TestPutFilterIrreversibleNotSupported() {
id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID
irreversible := true
_, err := suite.putFilter(id, nil, nil, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "")
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPutAnotherAccountsFilter() {
id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID
phrase := "GNU/Linux"
context := []string{"home"}
_, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}
func (suite *FiltersTestSuite) TestPutNonexistentFilter() {
id := "not_even_a_real_ULID"
phrase := "GNU/Linux"
context := []string{"home"}
_, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`)
if err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -15,7 +15,7 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
package filter package v1
import ( import (
"net/http" "net/http"
@ -26,9 +26,40 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
// FiltersGETHandler returns a list of filters set by/for the authed account // FiltersGETHandler swagger:operation GET /api/v1/filters filtersV1Get
//
// Get all filters for the authenticated account.
//
// ---
// tags:
// - filters
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - read:filters
//
// responses:
// '200':
// name: filter
// description: Requested filters.
// schema:
// "$ref": "#/definitions/filterV1"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) FiltersGETHandler(c *gin.Context) { func (m *Module) FiltersGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil { authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return return
} }
@ -38,5 +69,11 @@ func (m *Module) FiltersGETHandler(c *gin.Context) {
return return
} }
apiutil.Data(c, http.StatusOK, apiutil.AppJSON, apiutil.EmptyJSONArray) apiFilters, errWithCode := m.processor.FiltersV1().GetAll(c.Request.Context(), authed.Account)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiFilters)
} }

View file

@ -0,0 +1,114 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func (suite *FiltersTestSuite) getFilters(
expectedHTTPStatus int,
expectedBody string,
) ([]*apimodel.FilterV1, error) {
// instantiate recorder + test context
recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// create the request
ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil)
ctx.Request.Header.Set("accept", "application/json")
// trigger the handler
suite.filtersModule.FiltersGETHandler(ctx)
// read the response
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
return nil, err
}
errs := gtserror.NewMultiError(2)
// check code + body
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode)
}
// if we got an expected body, return early
if expectedBody != "" {
if string(b) != expectedBody {
errs.Appendf("expected %s got %s", expectedBody, string(b))
}
return nil, errs.Combine()
}
resp := make([]*apimodel.FilterV1, 0)
if err := json.Unmarshal(b, &resp); err != nil {
return nil, err
}
return resp, nil
}
func (suite *FiltersTestSuite) TestGetFilters() {
// v1 filters map to individual filter keywords.
expectedFilterIDs := make([]string, 0, len(suite.testFilterKeywords))
expectedFilterKeywords := make([]string, 0, len(suite.testFilterKeywords))
for _, filterKeyword := range suite.testFilterKeywords {
if filterKeyword.AccountID == suite.testAccounts["local_account_1"].ID {
expectedFilterIDs = append(expectedFilterIDs, filterKeyword.ID)
expectedFilterKeywords = append(expectedFilterKeywords, filterKeyword.Keyword)
}
}
suite.NotEmpty(expectedFilterIDs)
suite.NotEmpty(expectedFilterKeywords)
// Fetch all filters for the logged-in account.
filters, err := suite.getFilters(http.StatusOK, "")
if err != nil {
suite.FailNow(err.Error())
}
suite.NotEmpty(filters)
// Check that we got the right ones.
actualFilterIDs := make([]string, 0, len(filters))
actualFilterKeywords := make([]string, 0, len(filters))
for _, filter := range filters {
actualFilterIDs = append(actualFilterIDs, filter.ID)
actualFilterKeywords = append(actualFilterKeywords, filter.Phrase)
}
suite.ElementsMatch(expectedFilterIDs, actualFilterIDs)
suite.ElementsMatch(expectedFilterKeywords, actualFilterKeywords)
}

View file

@ -0,0 +1,68 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"errors"
"fmt"
"strconv"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/validate"
)
func validateNormalizeCreateUpdateFilter(form *model.FilterCreateUpdateRequestV1) error {
if err := validate.FilterKeyword(form.Phrase); err != nil {
return err
}
if err := validate.FilterContexts(form.Context); err != nil {
return err
}
// Apply defaults for missing fields.
form.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false))
form.Irreversible = util.Ptr(util.PtrValueOr(form.Irreversible, false))
if *form.Irreversible {
return errors.New("irreversible aka server-side drop filters are not supported yet")
}
// Normalize filter expiry if necessary.
// If we parsed this as JSON, expires_in
// may be either a float64 or a string.
if ei := form.ExpiresInI; ei != nil {
switch e := ei.(type) {
case float64:
form.ExpiresIn = util.Ptr(int(e))
case string:
expiresIn, err := strconv.Atoi(e)
if err != nil {
return fmt.Errorf("could not parse expires_in value %s as integer: %w", e, err)
}
form.ExpiresIn = &expiresIn
default:
return fmt.Errorf("could not parse expires_in type %T as integer", ei)
}
}
return nil
}

View file

@ -17,29 +17,23 @@
package model package model
// Filter represents a user-defined filter for determining which statuses should not be shown to the user. // FilterContext represents the context in which to apply a filter.
// If whole_word is true , client app should do: // v1 and v2 filter APIs use the same set of contexts.
// Define word constituent character for your app. In the official implementation, its [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. //
// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). // swagger:model filterContext
// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. type FilterContext string
// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match.
// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. const (
type Filter struct { // FilterContextHome means this filter should be applied to the home timeline and lists.
// The ID of the filter in the database. FilterContextHome FilterContext = "home"
ID string `json:"id"` // FilterContextNotifications means this filter should be applied to the notifications timeline.
// The text to be filtered. FilterContextNotifications FilterContext = "notifications"
Phrase string `json:"text"` // FilterContextPublic means this filter should be applied to public timelines.
// The contexts in which the filter should be applied. FilterContextPublic FilterContext = "public"
// Array of String (Enumerable anyOf) // FilterContextThread means this filter should be applied to the expanded thread of a detailed status.
// home = home timeline and lists FilterContextThread FilterContext = "thread"
// notifications = notifications timeline // FilterContextAccount means this filter should be applied when viewing a profile.
// public = public timelines FilterContextAccount FilterContext = "account"
// thread = expanded thread of a detailed status
Context []string `json:"context"` FilterContextNumValues = 5
// Should the filter consider word boundaries? )
WholeWord bool `json:"whole_word"`
// When the filter should no longer be applied (ISO 8601 Datetime), or null if the filter does not expire
ExpiresAt string `json:"expires_at,omitempty"`
// Should matching entities in home and notifications be dropped by the server?
Irreversible bool `json:"irreversible"`
}

View file

@ -0,0 +1,99 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package model
// FilterV1 represents a user-defined filter for determining which statuses should not be shown to the user.
// Note that v1 filters are mapped to v2 filters and v2 filter keywords internally.
// If whole_word is true, client app should do:
// Define word constituent character for your app. In the official implementation, its [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby.
// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation).
// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match.
// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match.
// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details.
//
// swagger:model filterV1
//
// ---
// tags:
// - filters
type FilterV1 struct {
// The ID of the filter in the database.
ID string `json:"id"`
// The text to be filtered.
//
// Example: fnord
Phrase string `json:"phrase"`
// The contexts in which the filter should be applied.
//
// Minimum length: 1
// Unique: true
// Enum:
// - home
// - notifications
// - public
// - thread
// - account
// Example: ["home", "public"]
Context []FilterContext `json:"context"`
// Should the filter consider word boundaries?
//
// Example: true
WholeWord bool `json:"whole_word"`
// Should matching entities be removed from the user's timelines/views, instead of hidden?
//
// Example: false
Irreversible bool `json:"irreversible"`
// When the filter should no longer be applied. Null if the filter does not expire.
//
// Example: 2024-02-01T02:57:49Z
ExpiresAt *string `json:"expires_at"`
}
// FilterCreateUpdateRequestV1 captures params for creating or updating a v1 filter.
//
// swagger:ignore
type FilterCreateUpdateRequestV1 struct {
// The text to be filtered.
//
// Required: true
// Maximum length: 40
// Example: fnord
Phrase string `form:"phrase" json:"phrase" xml:"phrase"`
// The contexts in which the filter should be applied.
//
// Required: true
// Minimum length: 1
// Unique: true
// Enum: home,notifications,public,thread,account
// Example: ["home", "public"]
Context []FilterContext `form:"context[]" json:"context" xml:"context"`
// Should matching entities be removed from the user's timelines/views, instead of hidden?
//
// Example: false
Irreversible *bool `form:"irreversible" json:"irreversible" xml:"irreversible"`
// Should the filter consider word boundaries?
//
// Example: true
WholeWord *bool `form:"whole_word" json:"whole_word" xml:"whole_word"`
// Number of seconds from now that the filter should expire. If omitted, filter never expires.
ExpiresIn *int `json:"-" form:"expires_in" xml:"expires_in"`
// Number of seconds from now that the filter should expire. If omitted, filter never expires.
//
// Example: 86400
ExpiresInI interface{} `json:"expires_in"`
}

View file

@ -61,6 +61,9 @@ func (c *Caches) Init() {
c.initDomainBlock() c.initDomainBlock()
c.initEmoji() c.initEmoji()
c.initEmojiCategory() c.initEmojiCategory()
c.initFilter()
c.initFilterKeyword()
c.initFilterStatus()
c.initFollow() c.initFollow()
c.initFollowIDs() c.initFollowIDs()
c.initFollowRequest() c.initFollowRequest()
@ -119,6 +122,9 @@ func (c *Caches) Sweep(threshold float64) {
c.GTS.BlockIDs.Trim(threshold) c.GTS.BlockIDs.Trim(threshold)
c.GTS.Emoji.Trim(threshold) c.GTS.Emoji.Trim(threshold)
c.GTS.EmojiCategory.Trim(threshold) c.GTS.EmojiCategory.Trim(threshold)
c.GTS.Filter.Trim(threshold)
c.GTS.FilterKeyword.Trim(threshold)
c.GTS.FilterStatus.Trim(threshold)
c.GTS.Follow.Trim(threshold) c.GTS.Follow.Trim(threshold)
c.GTS.FollowIDs.Trim(threshold) c.GTS.FollowIDs.Trim(threshold)
c.GTS.FollowRequest.Trim(threshold) c.GTS.FollowRequest.Trim(threshold)

108
internal/cache/db.go vendored
View file

@ -67,6 +67,15 @@ type GTSCaches struct {
// EmojiCategory provides access to the gtsmodel EmojiCategory database cache. // EmojiCategory provides access to the gtsmodel EmojiCategory database cache.
EmojiCategory structr.Cache[*gtsmodel.EmojiCategory] EmojiCategory structr.Cache[*gtsmodel.EmojiCategory]
// Filter provides access to the gtsmodel Filter database cache.
Filter structr.Cache[*gtsmodel.Filter]
// FilterKeyword provides access to the gtsmodel FilterKeyword database cache.
FilterKeyword structr.Cache[*gtsmodel.FilterKeyword]
// FilterStatus provides access to the gtsmodel FilterStatus database cache.
FilterStatus structr.Cache[*gtsmodel.FilterStatus]
// Follow provides access to the gtsmodel Follow database cache. // Follow provides access to the gtsmodel Follow database cache.
Follow structr.Cache[*gtsmodel.Follow] Follow structr.Cache[*gtsmodel.Follow]
@ -409,6 +418,105 @@ func (c *Caches) initEmojiCategory() {
}) })
} }
func (c *Caches) initFilter() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofFilter(), // model in-mem size.
config.GetCacheFilterMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
copyF := func(filter1 *gtsmodel.Filter) *gtsmodel.Filter {
filter2 := new(gtsmodel.Filter)
*filter2 = *filter1
// Don't include ptr fields that
// will be populated separately.
// See internal/db/bundb/filter.go.
filter2.Keywords = nil
filter2.Statuses = nil
return filter2
}
c.GTS.Filter.Init(structr.Config[*gtsmodel.Filter]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
{Fields: "AccountID", Multiple: true},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
CopyValue: copyF,
})
}
func (c *Caches) initFilterKeyword() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofFilterKeyword(), // model in-mem size.
config.GetCacheFilterKeywordMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
copyF := func(filterKeyword1 *gtsmodel.FilterKeyword) *gtsmodel.FilterKeyword {
filterKeyword2 := new(gtsmodel.FilterKeyword)
*filterKeyword2 = *filterKeyword1
// Don't include ptr fields that
// will be populated separately.
// See internal/db/bundb/filter.go.
filterKeyword2.Filter = nil
return filterKeyword2
}
c.GTS.FilterKeyword.Init(structr.Config[*gtsmodel.FilterKeyword]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
{Fields: "AccountID", Multiple: true},
{Fields: "FilterID", Multiple: true},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
CopyValue: copyF,
})
}
func (c *Caches) initFilterStatus() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofFilterStatus(), // model in-mem size.
config.GetCacheFilterStatusMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
copyF := func(filterStatus1 *gtsmodel.FilterStatus) *gtsmodel.FilterStatus {
filterStatus2 := new(gtsmodel.FilterStatus)
*filterStatus2 = *filterStatus1
// Don't include ptr fields that
// will be populated separately.
// See internal/db/bundb/filter.go.
filterStatus2.Filter = nil
return filterStatus2
}
c.GTS.FilterStatus.Init(structr.Config[*gtsmodel.FilterStatus]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
{Fields: "AccountID", Multiple: true},
{Fields: "FilterID", Multiple: true},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
CopyValue: copyF,
})
}
func (c *Caches) initFollow() { func (c *Caches) initFollow() {
// Calculate maximum cache size. // Calculate maximum cache size.
cap := calculateResultCacheMax( cap := calculateResultCacheMax(

View file

@ -309,6 +309,38 @@ func sizeofEmojiCategory() uintptr {
})) }))
} }
func sizeofFilter() uintptr {
return uintptr(size.Of(&gtsmodel.Filter{
ID: exampleID,
CreatedAt: exampleTime,
UpdatedAt: exampleTime,
ExpiresAt: exampleTime,
AccountID: exampleID,
Title: exampleTextSmall,
Action: gtsmodel.FilterActionHide,
}))
}
func sizeofFilterKeyword() uintptr {
return uintptr(size.Of(&gtsmodel.FilterKeyword{
ID: exampleID,
CreatedAt: exampleTime,
UpdatedAt: exampleTime,
FilterID: exampleID,
Keyword: exampleTextSmall,
}))
}
func sizeofFilterStatus() uintptr {
return uintptr(size.Of(&gtsmodel.FilterStatus{
ID: exampleID,
CreatedAt: exampleTime,
UpdatedAt: exampleTime,
FilterID: exampleID,
StatusID: exampleID,
}))
}
func sizeofFollow() uintptr { func sizeofFollow() uintptr {
return uintptr(size.Of(&gtsmodel.Follow{ return uintptr(size.Of(&gtsmodel.Follow{
ID: exampleID, ID: exampleID,

View file

@ -201,6 +201,9 @@ type CacheConfiguration struct {
BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"`
EmojiMemRatio float64 `name:"emoji-mem-ratio"` EmojiMemRatio float64 `name:"emoji-mem-ratio"`
EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"` EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"`
FilterMemRatio float64 `name:"filter-mem-ratio"`
FilterKeywordMemRatio float64 `name:"filter-keyword-mem-ratio"`
FilterStatusMemRatio float64 `name:"filter-status-mem-ratio"`
FollowMemRatio float64 `name:"follow-mem-ratio"` FollowMemRatio float64 `name:"follow-mem-ratio"`
FollowIDsMemRatio float64 `name:"follow-ids-mem-ratio"` FollowIDsMemRatio float64 `name:"follow-ids-mem-ratio"`
FollowRequestMemRatio float64 `name:"follow-request-mem-ratio"` FollowRequestMemRatio float64 `name:"follow-request-mem-ratio"`

View file

@ -165,6 +165,9 @@ var Defaults = Configuration{
BoostOfIDsMemRatio: 3, BoostOfIDsMemRatio: 3,
EmojiMemRatio: 3, EmojiMemRatio: 3,
EmojiCategoryMemRatio: 0.1, EmojiCategoryMemRatio: 0.1,
FilterMemRatio: 0.5,
FilterKeywordMemRatio: 0.5,
FilterStatusMemRatio: 0.5,
FollowMemRatio: 2, FollowMemRatio: 2,
FollowIDsMemRatio: 4, FollowIDsMemRatio: 4,
FollowRequestMemRatio: 2, FollowRequestMemRatio: 2,

View file

@ -2975,6 +2975,81 @@ func GetCacheEmojiCategoryMemRatio() float64 { return global.GetCacheEmojiCatego
// SetCacheEmojiCategoryMemRatio safely sets the value for global configuration 'Cache.EmojiCategoryMemRatio' field // SetCacheEmojiCategoryMemRatio safely sets the value for global configuration 'Cache.EmojiCategoryMemRatio' field
func SetCacheEmojiCategoryMemRatio(v float64) { global.SetCacheEmojiCategoryMemRatio(v) } func SetCacheEmojiCategoryMemRatio(v float64) { global.SetCacheEmojiCategoryMemRatio(v) }
// GetCacheFilterMemRatio safely fetches the Configuration value for state's 'Cache.FilterMemRatio' field
func (st *ConfigState) GetCacheFilterMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.FilterMemRatio
st.mutex.RUnlock()
return
}
// SetCacheFilterMemRatio safely sets the Configuration value for state's 'Cache.FilterMemRatio' field
func (st *ConfigState) SetCacheFilterMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.FilterMemRatio = v
st.reloadToViper()
}
// CacheFilterMemRatioFlag returns the flag name for the 'Cache.FilterMemRatio' field
func CacheFilterMemRatioFlag() string { return "cache-filter-mem-ratio" }
// GetCacheFilterMemRatio safely fetches the value for global configuration 'Cache.FilterMemRatio' field
func GetCacheFilterMemRatio() float64 { return global.GetCacheFilterMemRatio() }
// SetCacheFilterMemRatio safely sets the value for global configuration 'Cache.FilterMemRatio' field
func SetCacheFilterMemRatio(v float64) { global.SetCacheFilterMemRatio(v) }
// GetCacheFilterKeywordMemRatio safely fetches the Configuration value for state's 'Cache.FilterKeywordMemRatio' field
func (st *ConfigState) GetCacheFilterKeywordMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.FilterKeywordMemRatio
st.mutex.RUnlock()
return
}
// SetCacheFilterKeywordMemRatio safely sets the Configuration value for state's 'Cache.FilterKeywordMemRatio' field
func (st *ConfigState) SetCacheFilterKeywordMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.FilterKeywordMemRatio = v
st.reloadToViper()
}
// CacheFilterKeywordMemRatioFlag returns the flag name for the 'Cache.FilterKeywordMemRatio' field
func CacheFilterKeywordMemRatioFlag() string { return "cache-filter-keyword-mem-ratio" }
// GetCacheFilterKeywordMemRatio safely fetches the value for global configuration 'Cache.FilterKeywordMemRatio' field
func GetCacheFilterKeywordMemRatio() float64 { return global.GetCacheFilterKeywordMemRatio() }
// SetCacheFilterKeywordMemRatio safely sets the value for global configuration 'Cache.FilterKeywordMemRatio' field
func SetCacheFilterKeywordMemRatio(v float64) { global.SetCacheFilterKeywordMemRatio(v) }
// GetCacheFilterStatusMemRatio safely fetches the Configuration value for state's 'Cache.FilterStatusMemRatio' field
func (st *ConfigState) GetCacheFilterStatusMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.FilterStatusMemRatio
st.mutex.RUnlock()
return
}
// SetCacheFilterStatusMemRatio safely sets the Configuration value for state's 'Cache.FilterStatusMemRatio' field
func (st *ConfigState) SetCacheFilterStatusMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.FilterStatusMemRatio = v
st.reloadToViper()
}
// CacheFilterStatusMemRatioFlag returns the flag name for the 'Cache.FilterStatusMemRatio' field
func CacheFilterStatusMemRatioFlag() string { return "cache-filter-status-mem-ratio" }
// GetCacheFilterStatusMemRatio safely fetches the value for global configuration 'Cache.FilterStatusMemRatio' field
func GetCacheFilterStatusMemRatio() float64 { return global.GetCacheFilterStatusMemRatio() }
// SetCacheFilterStatusMemRatio safely sets the value for global configuration 'Cache.FilterStatusMemRatio' field
func SetCacheFilterStatusMemRatio(v float64) { global.SetCacheFilterStatusMemRatio(v) }
// GetCacheFollowMemRatio safely fetches the Configuration value for state's 'Cache.FollowMemRatio' field // GetCacheFollowMemRatio safely fetches the Configuration value for state's 'Cache.FollowMemRatio' field
func (st *ConfigState) GetCacheFollowMemRatio() (v float64) { func (st *ConfigState) GetCacheFollowMemRatio() (v float64) {
st.mutex.RLock() st.mutex.RLock()

View file

@ -62,6 +62,7 @@ type DBService struct {
db.Emoji db.Emoji
db.HeaderFilter db.HeaderFilter
db.Instance db.Instance
db.Filter
db.List db.List
db.Marker db.Marker
db.Media db.Media
@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db, db: db,
state: state, state: state,
}, },
Filter: &filterDB{
db: db,
state: state,
},
List: &listDB{ List: &listDB{
db: db, db: db,
state: state, state: state,

339
internal/db/bundb/filter.go Normal file
View file

@ -0,0 +1,339 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type filterDB struct {
db *bun.DB
state *state.State
}
func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) {
filter, err := f.state.Caches.GTS.Filter.LoadOne(
"ID",
func() (*gtsmodel.Filter, error) {
var filter gtsmodel.Filter
err := f.db.
NewSelect().
Model(&filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filter, err
},
id,
)
if err != nil {
// already processed
return nil, err
}
if !gtscontext.Barebones(ctx) {
if err := f.populateFilter(ctx, filter); err != nil {
return nil, err
}
}
return filter, nil
}
func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) {
// Fetch IDs of all filters owned by this account.
var filterIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.Filter)(nil)).
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
Scan(ctx, &filterIDs); err != nil {
return nil, err
}
if len(filterIDs) == 0 {
return nil, nil
}
// Get each filter by ID from the cache or DB.
uncachedFilterIDs := make([]string, 0, len(filterIDs))
filters, err := f.state.Caches.GTS.Filter.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterIDs {
if !load(id) {
uncachedFilterIDs = append(uncachedFilterIDs, id)
}
}
},
func() ([]*gtsmodel.Filter, error) {
uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilters).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilters, nil
},
)
if err != nil {
return nil, err
}
// Put the filter structs in the same order as the filter IDs.
util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID })
if gtscontext.Barebones(ctx) {
return filters, nil
}
// Populate the filters. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filters))
filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool {
if err := f.populateFilter(ctx, filter); err != nil {
errs.Appendf("error populating filter %s: %w", filter.ID, err)
return true
}
return false
})
return filters, errs.Combine()
}
func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error {
var err error
errs := gtserror.NewMultiError(2)
if filter.Keywords == nil {
// Filter keywords are not set, fetch from the database.
filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID(
gtscontext.SetBarebones(ctx),
filter.ID,
)
if err != nil {
errs.Appendf("error populating filter keywords: %w", err)
}
for i := range filter.Keywords {
filter.Keywords[i].Filter = filter
}
}
if filter.Statuses == nil {
// Filter statuses are not set, fetch from the database.
filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID(
gtscontext.SetBarebones(ctx),
filter.ID,
)
if err != nil {
errs.Appendf("error populating filter statuses: %w", err)
}
for i := range filter.Statuses {
filter.Statuses[i].Filter = filter
}
}
return errs.Combine()
}
func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error {
// Update database.
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.
NewInsert().
Model(filter).
Exec(ctx); err != nil {
return err
}
if len(filter.Keywords) > 0 {
if _, err := tx.
NewInsert().
Model(&filter.Keywords).
Exec(ctx); err != nil {
return err
}
}
if len(filter.Statuses) > 0 {
if _, err := tx.
NewInsert().
Model(&filter.Statuses).
Exec(ctx); err != nil {
return err
}
}
return nil
}); err != nil {
return err
}
// Update cache.
f.state.Caches.GTS.Filter.Put(filter)
f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
return nil
}
func (f *filterDB) UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
filterKeywordColumns []string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error {
updatedAt := time.Now()
filter.UpdatedAt = updatedAt
for _, filterKeyword := range filter.Keywords {
filterKeyword.UpdatedAt = updatedAt
}
for _, filterStatus := range filter.Statuses {
filterStatus.UpdatedAt = updatedAt
}
// If we're updating by column, ensure "updated_at" is included.
if len(filterColumns) > 0 {
filterColumns = append(filterColumns, "updated_at")
}
if len(filterKeywordColumns) > 0 {
filterKeywordColumns = append(filterKeywordColumns, "updated_at")
}
// Update database.
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.
NewUpdate().
Model(filter).
Column(filterColumns...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
if len(filter.Keywords) > 0 {
if _, err := NewUpsert(tx).
Model(&filter.Keywords).
Constraint("id").
Column(filterKeywordColumns...).
Exec(ctx); err != nil {
return err
}
}
if len(filter.Statuses) > 0 {
if _, err := tx.
NewInsert().
Ignore().
Model(&filter.Statuses).
Exec(ctx); err != nil {
return err
}
}
if len(deleteFilterKeywordIDs) > 0 {
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)).
Exec(ctx); err != nil {
return err
}
}
if len(deleteFilterStatusIDs) > 0 {
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)).
Exec(ctx); err != nil {
return err
}
}
return nil
}); err != nil {
return err
}
// Update cache.
f.state.Caches.GTS.Filter.Put(filter)
f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
// TODO: (Vyr) replace with cache multi-invalidate call
for _, id := range deleteFilterKeywordIDs {
f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
}
for _, id := range deleteFilterStatusIDs {
f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
}
return nil
}
func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error {
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete all keywords attached to filter.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = ?", bun.Ident("filter_id"), id).
Exec(ctx); err != nil {
return err
}
// Delete all statuses attached to filter.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = ?", bun.Ident("filter_id"), id).
Exec(ctx); err != nil {
return err
}
// Delete the filter itself.
_, err := tx.
NewDelete().
Model((*gtsmodel.Filter)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return err
}); err != nil {
return err
}
// Invalidate this filter.
f.state.Caches.GTS.Filter.Invalidate("ID", id)
// Invalidate all keywords and statuses for this filter.
f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id)
f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id)
return nil
}

View file

@ -0,0 +1,252 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type FilterTestSuite struct {
BunDBStandardTestSuite
}
// TestFilterCRUD tests CRUD and read-all operations on filters.
func (suite *FilterTestSuite) TestFilterCRUD() {
t := suite.T()
// Create new example filter with attached keyword.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
filterKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
Keyword: "GNU/Linux",
}
filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the example filter into db.
if err := suite.db.PutFilter(ctx, filter); err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// Now fetch newly created filter.
check, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
// Check all expected fields match.
suite.Equal(filter.ID, check.ID)
suite.Equal(filter.AccountID, check.AccountID)
suite.Equal(filter.Title, check.Title)
suite.Equal(filter.Action, check.Action)
suite.Equal(filter.ContextHome, check.ContextHome)
suite.Equal(filter.ContextNotifications, check.ContextNotifications)
suite.Equal(filter.ContextPublic, check.ContextPublic)
suite.Equal(filter.ContextThread, check.ContextThread)
suite.Equal(filter.ContextAccount, check.ContextAccount)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(len(filter.Keywords), len(check.Keywords))
suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID)
suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID)
suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword)
suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
suite.NotZero(check.Keywords[0].CreatedAt)
suite.NotZero(check.Keywords[0].UpdatedAt)
suite.Equal(len(filter.Statuses), len(check.Statuses))
// Fetch all filters.
all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filters: %v", err)
}
// Ensure the result contains our example filter.
suite.Len(all, 1)
suite.Equal(filter.ID, all[0].ID)
suite.Len(all[0].Keywords, 1)
suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID)
suite.Empty(all[0].Statuses)
// Update the filter context and add another keyword and a status.
check.ContextNotifications = util.Ptr(true)
newKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEMY810E5XKWDDMN5ZRE749",
FilterID: filter.ID,
AccountID: filter.AccountID,
Keyword: "tux",
}
check.Keywords = append(check.Keywords, newKeyword)
newStatus := &gtsmodel.FilterStatus{
ID: "01HNEMYD5XE7C8HH8TNCZ76FN2",
FilterID: filter.ID,
AccountID: filter.AccountID,
StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES",
}
check.Statuses = append(check.Statuses, newStatus)
if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
// Now fetch newly updated filter.
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure expected fields were modified on check filter.
suite.True(check.UpdatedAt.After(filter.UpdatedAt))
if suite.NotNil(check.ContextHome) {
suite.True(*check.ContextHome)
}
if suite.NotNil(check.ContextNotifications) {
suite.True(*check.ContextNotifications)
}
if suite.NotNil(check.ContextPublic) {
suite.True(*check.ContextPublic)
}
if suite.NotNil(check.ContextThread) {
suite.False(*check.ContextThread)
}
if suite.NotNil(check.ContextAccount) {
suite.False(*check.ContextAccount)
}
// Ensure keyword entries were added.
suite.Len(check.Keywords, 2)
checkFilterKeywordIDs := make([]string, 0, 2)
for _, checkFilterKeyword := range check.Keywords {
checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID)
}
suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs)
// Ensure status entry was added.
suite.Len(check.Statuses, 1)
checkFilterStatusIDs := make([]string, 0, 1)
for _, checkFilterStatus := range check.Statuses {
checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID)
}
suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs)
// Update one filter keyword and delete another. Don't change the filter or the filter status.
filterKeyword.WholeWord = util.Ptr(true)
check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
check.Statuses = nil
if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure expected fields were not modified.
suite.Equal(filter.Title, check.Title)
suite.Equal(gtsmodel.FilterActionWarn, check.Action)
if suite.NotNil(check.ContextHome) {
suite.True(*check.ContextHome)
}
if suite.NotNil(check.ContextNotifications) {
suite.True(*check.ContextNotifications)
}
if suite.NotNil(check.ContextPublic) {
suite.True(*check.ContextPublic)
}
if suite.NotNil(check.ContextThread) {
suite.False(*check.ContextThread)
}
if suite.NotNil(check.ContextAccount) {
suite.False(*check.ContextAccount)
}
// Ensure only changed field of keyword was modified, and other keyword was deleted.
suite.Len(check.Keywords, 1)
suite.Equal(filterKeyword.ID, check.Keywords[0].ID)
suite.Equal("GNU/Linux", check.Keywords[0].Keyword)
if suite.NotNil(check.Keywords[0].WholeWord) {
suite.True(*check.Keywords[0].WholeWord)
}
// Ensure status entry was not deleted.
suite.Len(check.Statuses, 1)
suite.Equal(newStatus.ID, check.Statuses[0].ID)
// Add another status entry for the same status ID. It should be ignored without problems.
redundantStatus := &gtsmodel.FilterStatus{
ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0",
FilterID: filter.ID,
AccountID: filter.AccountID,
StatusID: newStatus.StatusID,
}
check.Statuses = []*gtsmodel.FilterStatus{redundantStatus}
if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure status entry was not deleted, updated, or duplicated.
suite.Len(check.Statuses, 1)
suite.Equal(newStatus.ID, check.Statuses[0].ID)
suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID)
// Now delete the filter from the DB.
if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil {
t.Fatalf("error deleting filter: %v", err)
}
// Ensure we can't refetch it.
_, err = suite.db.GetFilterByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter returned unexpected error: %v", err)
}
}
func TestFilterTestSuite(t *testing.T) {
suite.Run(t, new(FilterTestSuite))
}

View file

@ -0,0 +1,191 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) {
filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne(
"ID",
func() (*gtsmodel.FilterKeyword, error) {
var filterKeyword gtsmodel.FilterKeyword
err := f.db.
NewSelect().
Model(&filterKeyword).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filterKeyword, err
},
id,
)
if err != nil {
return nil, err
}
if !gtscontext.Barebones(ctx) {
err = f.populateFilterKeyword(ctx, filterKeyword)
if err != nil {
return nil, err
}
}
return filterKeyword, nil
}
func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
if filterKeyword.Filter == nil {
// Filter is not set, fetch from the cache or database.
filter, err := f.state.DB.GetFilterByID(
// Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
gtscontext.SetBarebones(ctx),
filterKeyword.FilterID,
)
if err != nil {
return err
}
filterKeyword.Filter = filter
}
return nil
}
func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) {
return f.getFilterKeywords(ctx, "filter_id", filterID)
}
func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) {
return f.getFilterKeywords(ctx, "account_id", accountID)
}
func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) {
var filterKeywordIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.FilterKeyword)(nil)).
Column("id").
Where("? = ?", bun.Ident(idColumn), id).
Scan(ctx, &filterKeywordIDs); err != nil {
return nil, err
}
if len(filterKeywordIDs) == 0 {
return nil, nil
}
// Get each filter keyword by ID from the cache or DB.
uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs))
filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterKeywordIDs {
if !load(id) {
uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id)
}
}
},
func() ([]*gtsmodel.FilterKeyword, error) {
uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilterKeywords).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilterKeywords, nil
},
)
if err != nil {
return nil, err
}
// Put the filter keyword structs in the same order as the filter keyword IDs.
util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string {
return filterKeyword.ID
})
if gtscontext.Barebones(ctx) {
return filterKeywords, nil
}
// Populate the filter keywords. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filterKeywords))
filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool {
if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil {
errs.Appendf(
"error populating filter keyword %s: %w",
filterKeyword.ID,
err,
)
return true
}
return false
})
return filterKeywords, errs.Combine()
}
func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
_, err := f.db.
NewInsert().
Model(filterKeyword).
Exec(ctx)
return err
})
}
func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error {
filterKeyword.UpdatedAt = time.Now()
if len(columns) > 0 {
columns = append(columns, "updated_at")
}
return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
_, err := f.db.
NewUpdate().
Model(filterKeyword).
Where("? = ?", bun.Ident("id"), filterKeyword.ID).
Column(columns...).
Exec(ctx)
return err
})
}
func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error {
if _, err := f.db.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
return nil
}

View file

@ -0,0 +1,143 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords.
func (suite *FilterTestSuite) TestFilterKeywordCRUD() {
t := suite.T()
// Create new filter.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the new filter into the DB.
err := suite.db.PutFilter(ctx, filter)
if err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// There should be no filter keywords yet.
all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Empty(all)
// Add a filter keyword to it.
filterKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
Keyword: "GNU/Linux",
}
// Insert the new filter keyword into the DB.
err = suite.db.PutFilterKeyword(ctx, filterKeyword)
if err != nil {
t.Fatalf("error inserting filter keyword: %v", err)
}
// Try to find it again and ensure it has the fields we expect.
check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
if err != nil {
t.Fatalf("error fetching filter keyword: %v", err)
}
suite.Equal(filterKeyword.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(filterKeyword.AccountID, check.AccountID)
suite.Equal(filterKeyword.FilterID, check.FilterID)
suite.Equal(filterKeyword.Keyword, check.Keyword)
suite.Equal(filterKeyword.WholeWord, check.WholeWord)
// Loading filter keywords by account ID should find the one we inserted.
all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterKeyword.ID, all[0].ID)
// Loading filter keywords by filter ID should also find the one we inserted.
all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterKeyword.ID, all[0].ID)
// Modify the filter keyword.
filterKeyword.WholeWord = util.Ptr(true)
err = suite.db.UpdateFilterKeyword(ctx, filterKeyword)
if err != nil {
t.Fatalf("error updating filter keyword: %v", err)
}
// Try to find it again and ensure it has the updated fields we expect.
check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
if err != nil {
t.Fatalf("error fetching filter keyword: %v", err)
}
suite.Equal(filterKeyword.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.True(check.UpdatedAt.After(check.CreatedAt))
suite.Equal(filterKeyword.AccountID, check.AccountID)
suite.Equal(filterKeyword.FilterID, check.FilterID)
suite.Equal(filterKeyword.Keyword, check.Keyword)
suite.Equal(filterKeyword.WholeWord, check.WholeWord)
// Delete the filter keyword from the DB.
err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error deleting filter keyword: %v", err)
}
// Ensure we can't refetch it.
check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err)
}
suite.Nil(check)
// Ensure the filter itself is still there.
checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
suite.Equal(filter.ID, checkFilter.ID)
}

View file

@ -0,0 +1,191 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) {
filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne(
"ID",
func() (*gtsmodel.FilterStatus, error) {
var filterStatus gtsmodel.FilterStatus
err := f.db.
NewSelect().
Model(&filterStatus).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filterStatus, err
},
id,
)
if err != nil {
return nil, err
}
if !gtscontext.Barebones(ctx) {
err = f.populateFilterStatus(ctx, filterStatus)
if err != nil {
return nil, err
}
}
return filterStatus, nil
}
func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
if filterStatus.Filter == nil {
// Filter is not set, fetch from the cache or database.
filter, err := f.state.DB.GetFilterByID(
// Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
gtscontext.SetBarebones(ctx),
filterStatus.FilterID,
)
if err != nil {
return err
}
filterStatus.Filter = filter
}
return nil
}
func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) {
return f.getFilterStatuses(ctx, "filter_id", filterID)
}
func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) {
return f.getFilterStatuses(ctx, "account_id", accountID)
}
func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) {
var filterStatusIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.FilterStatus)(nil)).
Column("id").
Where("? = ?", bun.Ident(idColumn), id).
Scan(ctx, &filterStatusIDs); err != nil {
return nil, err
}
if len(filterStatusIDs) == 0 {
return nil, nil
}
// Get each filter status by ID from the cache or DB.
uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs))
filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterStatusIDs {
if !load(id) {
uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id)
}
}
},
func() ([]*gtsmodel.FilterStatus, error) {
uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilterStatuses).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilterStatuses, nil
},
)
if err != nil {
return nil, err
}
// Put the filter status structs in the same order as the filter status IDs.
util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string {
return filterStatus.ID
})
if gtscontext.Barebones(ctx) {
return filterStatuses, nil
}
// Populate the filter statuses. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filterStatuses))
filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool {
if err := f.populateFilterStatus(ctx, filterStatus); err != nil {
errs.Appendf(
"error populating filter status %s: %w",
filterStatus.ID,
err,
)
return true
}
return false
})
return filterStatuses, errs.Combine()
}
func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
_, err := f.db.
NewInsert().
Model(filterStatus).
Exec(ctx)
return err
})
}
func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error {
filterStatus.UpdatedAt = time.Now()
if len(columns) > 0 {
columns = append(columns, "updated_at")
}
return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
_, err := f.db.
NewUpdate().
Model(filterStatus).
Where("? = ?", bun.Ident("id"), filterStatus.ID).
Column(columns...).
Exec(ctx)
return err
})
}
func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error {
if _, err := f.db.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
return nil
}

View file

@ -0,0 +1,122 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses.
func (suite *FilterTestSuite) TestFilterStatusCRD() {
t := suite.T()
// Create new filter.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the new filter into the DB.
err := suite.db.PutFilter(ctx, filter)
if err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// There should be no filter statuses yet.
all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Empty(all)
// Add a filter status to it.
filterStatus := &gtsmodel.FilterStatus{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X",
}
// Insert the new filter status into the DB.
err = suite.db.PutFilterStatus(ctx, filterStatus)
if err != nil {
t.Fatalf("error inserting filter status: %v", err)
}
// Try to find it again and ensure it has the fields we expect.
check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID)
if err != nil {
t.Fatalf("error fetching filter status: %v", err)
}
suite.Equal(filterStatus.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(filterStatus.AccountID, check.AccountID)
suite.Equal(filterStatus.FilterID, check.FilterID)
suite.Equal(filterStatus.StatusID, check.StatusID)
// Loading filter statuses by account ID should find the one we inserted.
all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterStatus.ID, all[0].ID)
// Loading filter statuses by filter ID should also find the one we inserted.
all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterStatus.ID, all[0].ID)
// Delete the filter status from the DB.
err = suite.db.DeleteFilterStatusByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error deleting filter status: %v", err)
}
// Ensure we can't refetch it.
check, err = suite.db.GetFilterStatusByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter status returned unexpected error: %v", err)
}
suite.Nil(check)
// Ensure the filter itself is still there.
checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
suite.Equal(filter.ID, checkFilter.ID)
}

View file

@ -0,0 +1,97 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Filter table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.Filter{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Filter keyword table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.FilterKeyword{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Filter status table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.FilterStatus{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the filter tables.
for table, indexes := range map[string]map[string][]string{
"filters": {
"filters_account_id_idx": {"account_id"},
},
"filter_keywords": {
"filter_keywords_account_id_idx": {"account_id"},
"filter_keywords_filter_id_idx": {"filter_id"},
},
"filter_statuses": {
"filter_statuses_account_id_idx": {"account_id"},
"filter_statuses_filter_id_idx": {"filter_id"},
},
} {
for index, columns := range indexes {
if _, err := tx.
NewCreateIndex().
Table(table).
Index(index).
Column(columns...).
IfNotExists().
Exec(ctx); err != nil {
return err
}
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

230
internal/db/bundb/upsert.go Normal file
View file

@ -0,0 +1,230 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"reflect"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// UpsertQuery is a wrapper around an insert query that can update if an insert fails.
// Doesn't implement the full set of Bun query methods, but we can add more if we need them.
// See https://bun.uptrace.dev/guide/query-insert.html#upsert
type UpsertQuery struct {
db bun.IDB
model interface{}
constraints []string
columns []string
}
func NewUpsert(idb bun.IDB) *UpsertQuery {
// note: passing in rawtx as conn iface so no double query-hook
// firing when passed through the bun.Tx.Query___() functions.
return &UpsertQuery{db: idb}
}
// Model sets the model or models to upsert.
func (u *UpsertQuery) Model(model interface{}) *UpsertQuery {
u.model = model
return u
}
// Constraint sets the columns or indexes that are used to check for conflicts.
// This is required.
func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery {
u.constraints = constraints
return u
}
// Column sets the columns to update if an insert does't happen.
// If empty, all columns not being used for constraints will be updated.
// Cannot overlap with Constraint.
func (u *UpsertQuery) Column(columns ...string) *UpsertQuery {
u.columns = columns
return u
}
// insertDialect errors if we're using a dialect in which we don't know how to upsert.
func (u *UpsertQuery) insertDialect() error {
dialectName := u.db.Dialect().Name()
switch dialectName {
case dialect.PG, dialect.SQLite:
return nil
default:
// FUTURE: MySQL has its own variation on upserts, but the syntax is different.
return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName)
}
}
// insertConstraints checks that we have constraints and returns them.
func (u *UpsertQuery) insertConstraints() ([]string, error) {
if len(u.constraints) == 0 {
return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided")
}
return u.constraints, nil
}
// insertColumns returns the non-constraint columns we'll be updating.
func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) {
// Constraints as a set.
constraintSet := make(map[string]struct{}, len(constraints))
for _, constraint := range constraints {
constraintSet[constraint] = struct{}{}
}
var columns []string
var err error
if len(u.columns) == 0 {
columns, err = u.insertColumnsDefault(constraintSet)
} else {
columns, err = u.insertColumnsSpecified(constraintSet)
}
if err != nil {
return nil, err
}
if len(columns) == 0 {
return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting")
}
return columns, nil
}
// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking.
func hasElem(modelType reflect.Type) bool {
switch modelType.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice:
return true
default:
return false
}
}
// insertColumnsDefault returns all non-constraint columns from the model schema.
func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) {
// Get underlying struct type.
modelType := reflect.TypeOf(u.model)
for hasElem(modelType) {
modelType = modelType.Elem()
}
table := u.db.Dialect().Tables().Get(modelType)
if table == nil {
return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model)
}
columns := make([]string, 0, len(u.columns))
for _, field := range table.Fields {
column := field.Name
if _, overlaps := constraintSet[column]; !overlaps {
columns = append(columns, column)
}
}
return columns, nil
}
// insertColumnsSpecified ensures constraints and specified columns to update don't overlap.
func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) {
overlapping := make([]string, 0, min(len(u.constraints), len(u.columns)))
for _, column := range u.columns {
if _, overlaps := constraintSet[column]; overlaps {
overlapping = append(overlapping, column)
}
}
if len(overlapping) > 0 {
return nil, gtserror.Newf(
"UpsertQuery: the following columns can't be used for both constraints and columns to update: %s",
strings.Join(overlapping, ", "),
)
}
return u.columns, nil
}
// insert tries to create a Bun insert query from an upsert query.
func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
var err error
err = u.insertDialect()
if err != nil {
return nil, err
}
constraints, err := u.insertConstraints()
if err != nil {
return nil, err
}
columns, err := u.insertColumns(constraints)
if err != nil {
return nil, err
}
// Build the parts of the query that need us to generate SQL.
constraintIDPlaceholders := make([]string, 0, len(constraints))
constraintIDs := make([]interface{}, 0, len(constraints))
for _, constraint := range constraints {
constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
constraintIDs = append(constraintIDs, bun.Ident(constraint))
}
onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
setClauses := make([]string, 0, len(columns))
setIDs := make([]interface{}, 0, 2*len(columns))
for _, column := range columns {
// "excluded" is a special table that contains only the row involved in a conflict.
setClauses = append(setClauses, "? = excluded.?")
setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
}
setSQL := strings.Join(setClauses, ", ")
insertQuery := u.db.
NewInsert().
Model(u.model).
On(onSQL, constraintIDs...).
Set(setSQL, setIDs...)
return insertQuery, nil
}
// Exec builds a Bun insert query from the upsert query, and executes it.
func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
insertQuery, err := u.insertQuery()
if err != nil {
return nil, err
}
return insertQuery.Exec(ctx, dest...)
}
// Scan builds a Bun insert query from the upsert query, and scans it.
func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
insertQuery, err := u.insertQuery()
if err != nil {
return err
}
return insertQuery.Scan(ctx, dest...)
}

View file

@ -32,6 +32,7 @@ type DB interface {
Emoji Emoji
HeaderFilter HeaderFilter
Instance Instance
Filter
List List
Marker Marker
Media Media

101
internal/db/filter.go Normal file
View file

@ -0,0 +1,101 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries.
type Filter interface {
//<editor-fold desc="Filter methods">
// GetFilterByID gets one filter with the given id.
GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error)
// GetFiltersForAccountID gets all filters owned by the given accountID.
GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error)
// PutFilter puts a new filter in the database, adding any attached keywords or statuses.
// It uses a transaction to ensure no partial updates.
PutFilter(ctx context.Context, filter *gtsmodel.Filter) error
// UpdateFilter updates the given filter,
// upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated),
// and deletes indicated filter keywords and statuses by ID.
// It uses a transaction to ensure no partial updates.
// The column lists are optional; if not specified, all columns will be updated.
UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
filterKeywordColumns []string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error
// DeleteFilterByID deletes one filter with the given ID.
// It uses a transaction to ensure no partial updates.
DeleteFilterByID(ctx context.Context, id string) error
//</editor-fold>
//<editor-fold desc="Filter keyword methods">
// GetFilterKeywordByID gets one filter keyword with the given ID.
GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error)
// GetFilterKeywordsForFilterID gets filter keywords from the given filterID.
GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error)
// GetFilterKeywordsForAccountID gets filter keywords from the given accountID.
GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error)
// PutFilterKeyword inserts a single filter keyword into the database.
PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error
// UpdateFilterKeyword updates the given filter keyword.
// Columns is optional, if not specified all will be updated.
UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error
// DeleteFilterKeywordByID deletes one filter keyword with the given id.
DeleteFilterKeywordByID(ctx context.Context, id string) error
//</editor-fold>
//<editor-fold desc="Filter status methods">
// GetFilterStatusByID gets one filter status with the given ID.
GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error)
// GetFilterStatusesForFilterID gets filter statuses from the given filterID.
GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error)
// GetFilterStatusesForAccountID gets filter keywords from the given accountID.
GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error)
// PutFilterStatus inserts a single filter status into the database.
PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error
// DeleteFilterStatusByID deletes one filter status with the given id.
DeleteFilterStatusByID(ctx context.Context, id string) error
//</editor-fold>
}

View file

@ -0,0 +1,71 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// Filter stores a filter created by a local account.
type Filter struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
ExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Time filter should expire. If null, should not expire.
AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter.
Title string `bun:",nullzero,notnull,unique"` // The name of the filter.
Action FilterAction `bun:",nullzero,notnull"` // The action to take.
Keywords []*FilterKeyword `bun:"-"` // Keywords for this filter.
Statuses []*FilterStatus `bun:"-"` // Statuses for this filter.
ContextHome *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists.
ContextNotifications *bool `bun:",nullzero,notnull,default:false"` // Apply filter to notifications.
ContextPublic *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists.
ContextThread *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing a status's associated thread.
ContextAccount *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing an account profile.
}
// FilterKeyword stores a single keyword to filter statuses against.
type FilterKeyword struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword.
FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_keywords_filter_id_keyword_uniq"` // ID of the filter that this keyword belongs to.
Filter *Filter `bun:"-"` // Filter corresponding to FilterID
Keyword string `bun:",nullzero,notnull,unique:filter_keywords_filter_id_keyword_uniq"` // The keyword or phrase to filter against.
WholeWord *bool `bun:",nullzero,notnull,default:false"` // Should the filter consider word boundaries?
}
// FilterStatus stores a single status to filter.
type FilterStatus struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword.
FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the filter that this keyword belongs to.
Filter *Filter `bun:"-"` // Filter corresponding to FilterID
StatusID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the status to filter.
}
// FilterAction represents the action to take on a filtered status.
type FilterAction string
const (
// FilterActionWarn means that the status should be shown behind a warning.
FilterActionWarn FilterAction = "warn"
// FilterActionHide means that the status should be removed from timeline results entirely.
FilterActionHide FilterAction = "hide"
)

View file

@ -0,0 +1,38 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// apiFilter is a shortcut to return the API v1 filter version of the given
// filter keyword, or return an appropriate error if conversion fails.
func (p *Processor) apiFilter(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, gtserror.WithCode) {
apiFilter, err := p.converter.FilterKeywordToAPIFilterV1(ctx, filterKeyword)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting filter keyword to API v1 filter: %w", err))
}
return apiFilter, nil
}

View file

@ -0,0 +1,87 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"context"
"errors"
"fmt"
"time"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// Create a new filter and filter keyword for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.FilterCreateUpdateRequestV1) (*apimodel.FilterV1, gtserror.WithCode) {
filter := &gtsmodel.Filter{
ID: id.NewULID(),
AccountID: account.ID,
Title: form.Phrase,
Action: gtsmodel.FilterActionWarn,
}
if *form.Irreversible {
filter.Action = gtsmodel.FilterActionHide
}
if form.ExpiresIn != nil {
filter.ExpiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn))
}
for _, context := range form.Context {
switch context {
case apimodel.FilterContextHome:
filter.ContextHome = util.Ptr(true)
case apimodel.FilterContextNotifications:
filter.ContextNotifications = util.Ptr(true)
case apimodel.FilterContextPublic:
filter.ContextPublic = util.Ptr(true)
case apimodel.FilterContextThread:
filter.ContextThread = util.Ptr(true)
case apimodel.FilterContextAccount:
filter.ContextAccount = util.Ptr(true)
default:
return nil, gtserror.NewErrorUnprocessableEntity(
fmt.Errorf("unsupported filter context '%s'", context),
)
}
}
filterKeyword := &gtsmodel.FilterKeyword{
ID: id.NewULID(),
AccountID: account.ID,
FilterID: filter.ID,
Filter: filter,
Keyword: form.Phrase,
WholeWord: util.Ptr(util.PtrValueOr(form.WholeWord, false)),
}
filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
if err := p.state.DB.PutFilter(ctx, filter); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a filter with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiFilter(ctx, filterKeyword)
}

View file

@ -0,0 +1,67 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Delete an existing filter keyword and (if empty afterwards) filter for the given account.
func (p *Processor) Delete(
ctx context.Context,
account *gtsmodel.Account,
filterKeywordID string,
) gtserror.WithCode {
// Get enough of the filter keyword that we can look up its filter ID.
filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return gtserror.NewErrorNotFound(err)
}
return gtserror.NewErrorInternalError(err)
}
if filterKeyword.AccountID != account.ID {
return gtserror.NewErrorNotFound(nil)
}
// Get the filter for this keyword.
filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID)
if err != nil {
return gtserror.NewErrorNotFound(err)
}
if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 {
// The filter has other keywords or statuses. Delete only the requested filter keyword.
if err := p.state.DB.DeleteFilterKeywordByID(ctx, filterKeyword.ID); err != nil {
return gtserror.NewErrorInternalError(err)
}
} else {
// Delete the entire filter.
if err := p.state.DB.DeleteFilterByID(ctx, filter.ID); err != nil {
return gtserror.NewErrorInternalError(err)
}
}
return nil
}

View file

@ -0,0 +1,35 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
state *state.State
converter *typeutils.Converter
}
func New(state *state.State, converter *typeutils.Converter) Processor {
return Processor{
state: state,
converter: converter,
}
}

View file

@ -0,0 +1,78 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"context"
"errors"
"slices"
"strings"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Get looks up a filter keyword by ID and returns it as a v1 filter.
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, filterKeywordID string) (*apimodel.FilterV1, gtserror.WithCode) {
filterKeyword, err := p.state.DB.GetFilterKeywordByID(ctx, filterKeywordID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
}
if filterKeyword.AccountID != account.ID {
return nil, gtserror.NewErrorNotFound(nil)
}
return p.apiFilter(ctx, filterKeyword)
}
// GetAll looks up all filter keywords for the current account and returns them as v1 filters.
func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.FilterV1, gtserror.WithCode) {
filters, err := p.state.DB.GetFilterKeywordsForAccountID(
ctx,
account.ID,
)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, nil
}
return nil, gtserror.NewErrorInternalError(err)
}
apiFilters := make([]*apimodel.FilterV1, 0, len(filters))
for _, list := range filters {
apiFilter, errWithCode := p.apiFilter(ctx, list)
if errWithCode != nil {
return nil, errWithCode
}
apiFilters = append(apiFilters, apiFilter)
}
// Sort them by ID so that they're in a stable order.
// Clients may opt to sort them lexically in a locale-aware manner.
slices.SortFunc(apiFilters, func(lhs *apimodel.FilterV1, rhs *apimodel.FilterV1) int {
return strings.Compare(lhs.ID, rhs.ID)
})
return apiFilters, nil
}

View file

@ -0,0 +1,165 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package v1
import (
"context"
"errors"
"fmt"
"strings"
"time"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// Update an existing filter and filter keyword for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Update(
ctx context.Context,
account *gtsmodel.Account,
filterKeywordID string,
form *apimodel.FilterCreateUpdateRequestV1,
) (*apimodel.FilterV1, gtserror.WithCode) {
// Get enough of the filter keyword that we can look up its filter ID.
filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
}
if filterKeyword.AccountID != account.ID {
return nil, gtserror.NewErrorNotFound(nil)
}
// Get the filter for this keyword.
filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
}
title := form.Phrase
action := gtsmodel.FilterActionWarn
if *form.Irreversible {
action = gtsmodel.FilterActionHide
}
expiresAt := time.Time{}
if form.ExpiresIn != nil {
expiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn))
}
contextHome := false
contextNotifications := false
contextPublic := false
contextThread := false
contextAccount := false
for _, context := range form.Context {
switch context {
case apimodel.FilterContextHome:
contextHome = true
case apimodel.FilterContextNotifications:
contextNotifications = true
case apimodel.FilterContextPublic:
contextPublic = true
case apimodel.FilterContextThread:
contextThread = true
case apimodel.FilterContextAccount:
contextAccount = true
default:
return nil, gtserror.NewErrorUnprocessableEntity(
fmt.Errorf("unsupported filter context '%s'", context),
)
}
}
// v1 filter APIs can't change certain fields for a filter with multiple keywords or any statuses,
// since it would be an unexpected side effect on filters that, to the v1 API, appear separate.
// See https://docs.joinmastodon.org/methods/filters/#update-v1
if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 {
forbiddenFields := make([]string, 0, 4)
if title != filter.Title {
forbiddenFields = append(forbiddenFields, "phrase")
}
if action != filter.Action {
forbiddenFields = append(forbiddenFields, "irreversible")
}
if expiresAt != filter.ExpiresAt {
forbiddenFields = append(forbiddenFields, "expires_in")
}
if contextHome != util.PtrValueOr(filter.ContextHome, false) ||
contextNotifications != util.PtrValueOr(filter.ContextNotifications, false) ||
contextPublic != util.PtrValueOr(filter.ContextPublic, false) ||
contextThread != util.PtrValueOr(filter.ContextThread, false) ||
contextAccount != util.PtrValueOr(filter.ContextAccount, false) {
forbiddenFields = append(forbiddenFields, "context")
}
if len(forbiddenFields) > 0 {
return nil, gtserror.NewErrorUnprocessableEntity(
fmt.Errorf("v1 filter backwards compatibility: can't change these fields for a filter with multiple keywords or any statuses: %s", strings.Join(forbiddenFields, ", ")),
)
}
}
// Now that we've checked that the changes are legal, apply them to the filter and keyword.
filter.Title = title
filter.Action = action
filter.ExpiresAt = expiresAt
filter.ContextHome = &contextHome
filter.ContextNotifications = &contextNotifications
filter.ContextPublic = &contextPublic
filter.ContextThread = &contextThread
filter.ContextAccount = &contextAccount
filterKeyword.Keyword = form.Phrase
filterKeyword.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false))
// We only want to update the relevant filter keyword.
filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
filter.Statuses = nil
filterKeyword.Filter = filter
filterColumns := []string{
"title",
"action",
"expires_at",
"context_home",
"context_notifications",
"context_public",
"context_thread",
"context_account",
}
filterKeywordColumns := []string{
"keyword",
"whole_word",
}
if err := p.state.DB.UpdateFilter(ctx, filter, filterColumns, filterKeywordColumns, nil, nil); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a filter with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiFilter(ctx, filterKeyword)
}

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/admin" "github.com/superseriousbusiness/gotosocial/internal/processing/admin"
"github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/processing/fedi" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi"
filtersv1 "github.com/superseriousbusiness/gotosocial/internal/processing/filters/v1"
"github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/list"
"github.com/superseriousbusiness/gotosocial/internal/processing/markers" "github.com/superseriousbusiness/gotosocial/internal/processing/markers"
"github.com/superseriousbusiness/gotosocial/internal/processing/media" "github.com/superseriousbusiness/gotosocial/internal/processing/media"
@ -68,20 +69,21 @@ type Processor struct {
SUB-PROCESSORS SUB-PROCESSORS
*/ */
account account.Processor account account.Processor
admin admin.Processor admin admin.Processor
fedi fedi.Processor fedi fedi.Processor
list list.Processor filtersv1 filtersv1.Processor
markers markers.Processor list list.Processor
media media.Processor markers markers.Processor
polls polls.Processor media media.Processor
report report.Processor polls polls.Processor
search search.Processor report report.Processor
status status.Processor search search.Processor
stream stream.Processor status status.Processor
timeline timeline.Processor stream stream.Processor
user user.Processor timeline timeline.Processor
workers workers.Processor user user.Processor
workers workers.Processor
} }
func (p *Processor) Account() *account.Processor { func (p *Processor) Account() *account.Processor {
@ -96,6 +98,10 @@ func (p *Processor) Fedi() *fedi.Processor {
return &p.fedi return &p.fedi
} }
func (p *Processor) FiltersV1() *filtersv1.Processor {
return &p.filtersv1
}
func (p *Processor) List() *list.Processor { func (p *Processor) List() *list.Processor {
return &p.list return &p.list
} }
@ -177,6 +183,7 @@ func NewProcessor(
processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender)
processor.fedi = fedi.New(state, &common, converter, federator, filter) processor.fedi = fedi.New(state, &common, converter, federator, filter)
processor.filtersv1 = filtersv1.New(state, converter)
processor.list = list.New(state, converter) processor.list = list.New(state, converter)
processor.markers = markers.New(state, converter) processor.markers = markers.New(state, converter)
processor.polls = polls.New(&common, state, converter) processor.polls = polls.New(&common, state, converter)

View file

@ -111,7 +111,6 @@ func (p *Processor) contextGet(
TopoSort(descendants, targetStatus.AccountID) TopoSort(descendants, targetStatus.AccountID)
//goland:noinspection GoImportUsedAsName
context := &apimodel.Context{ context := &apimodel.Context{
Ancestors: make([]apimodel.Status, 0, len(ancestors)), Ancestors: make([]apimodel.Status, 0, len(ancestors)),
Descendants: make([]apimodel.Status, 0, len(descendants)), Descendants: make([]apimodel.Status, 0, len(descendants)),

View file

@ -1617,6 +1617,59 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta
return apiAttachments, errs.Combine() return apiAttachments, errs.Combine()
} }
// FilterToAPIFiltersV1 converts one GTS model filter into an API v1 filter list
func (c *Converter) FilterToAPIFiltersV1(ctx context.Context, filter *gtsmodel.Filter) ([]*apimodel.FilterV1, error) {
apiFilters := make([]*apimodel.FilterV1, 0, len(filter.Keywords))
for _, filterKeyword := range filter.Keywords {
apiFilter, err := c.FilterKeywordToAPIFilterV1(ctx, filterKeyword)
if err != nil {
return nil, err
}
apiFilters = append(apiFilters, apiFilter)
}
return apiFilters, nil
}
// FilterKeywordToAPIFilterV1 converts one GTS model filter and filter keyword into an API v1 filter
func (c *Converter) FilterKeywordToAPIFilterV1(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, error) {
if filterKeyword.Filter == nil {
return nil, gtserror.New("FilterKeyword model's Filter field isn't populated, but needs to be")
}
filter := filterKeyword.Filter
apiContexts := make([]apimodel.FilterContext, 0, apimodel.FilterContextNumValues)
if util.PtrValueOr(filter.ContextHome, false) {
apiContexts = append(apiContexts, apimodel.FilterContextHome)
}
if util.PtrValueOr(filter.ContextNotifications, false) {
apiContexts = append(apiContexts, apimodel.FilterContextNotifications)
}
if util.PtrValueOr(filter.ContextPublic, false) {
apiContexts = append(apiContexts, apimodel.FilterContextPublic)
}
if util.PtrValueOr(filter.ContextThread, false) {
apiContexts = append(apiContexts, apimodel.FilterContextThread)
}
if util.PtrValueOr(filter.ContextAccount, false) {
apiContexts = append(apiContexts, apimodel.FilterContextAccount)
}
var expiresAt *string
if !filter.ExpiresAt.IsZero() {
expiresAt = util.Ptr(util.FormatISO8601(filter.ExpiresAt))
}
return &apimodel.FilterV1{
// v1 filters have a single keyword each, so we use the filter keyword ID as the v1 filter ID.
ID: filterKeyword.ID,
Phrase: filterKeyword.Keyword,
Context: apiContexts,
WholeWord: util.PtrValueOr(filterKeyword.WholeWord, false),
ExpiresAt: expiresAt,
Irreversible: filter.Action == gtsmodel.FilterActionHide,
}, nil
}
// convertEmojisToAPIEmojis will convert a slice of GTS model emojis to frontend API model emojis, falling back to IDs if no GTS models supplied. // convertEmojisToAPIEmojis will convert a slice of GTS model emojis to frontend API model emojis, falling back to IDs if no GTS models supplied.
func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) { func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) {
var errs gtserror.MultiError var errs gtserror.MultiError

View file

@ -44,6 +44,7 @@ const (
maximumProfileFieldLength = 255 maximumProfileFieldLength = 255
maximumProfileFields = 6 maximumProfileFields = 6
maximumListTitleLength = 200 maximumListTitleLength = 200
maximumFilterKeywordLength = 40
) )
// Password returns a helpful error if the given password // Password returns a helpful error if the given password
@ -306,3 +307,44 @@ func MarkerName(name string) error {
} }
return fmt.Errorf("marker timeline name '%s' was not recognized, valid options are '%s', '%s'", name, apimodel.MarkerNameHome, apimodel.MarkerNameNotifications) return fmt.Errorf("marker timeline name '%s' was not recognized, valid options are '%s', '%s'", name, apimodel.MarkerNameHome, apimodel.MarkerNameNotifications)
} }
// FilterKeyword validates the title of a new or updated List.
func FilterKeyword(keyword string) error {
if keyword == "" {
return fmt.Errorf("filter keyword must be provided, and must be no more than %d chars", maximumFilterKeywordLength)
}
if length := len([]rune(keyword)); length > maximumFilterKeywordLength {
return fmt.Errorf("filter keyword length must be no more than %d chars, provided keyword was %d chars", maximumFilterKeywordLength, length)
}
return nil
}
// FilterContexts validates the context of a new or updated filter.
func FilterContexts(contexts []apimodel.FilterContext) error {
if len(contexts) == 0 {
return fmt.Errorf("at least one filter context is required")
}
for _, context := range contexts {
switch context {
case apimodel.FilterContextHome,
apimodel.FilterContextNotifications,
apimodel.FilterContextPublic,
apimodel.FilterContextThread,
apimodel.FilterContextAccount:
continue
default:
return fmt.Errorf(
"filter context '%s' was not recognized, valid options are '%s', '%s', '%s', '%s', '%s'",
context,
apimodel.FilterContextHome,
apimodel.FilterContextNotifications,
apimodel.FilterContextPublic,
apimodel.FilterContextThread,
apimodel.FilterContextAccount,
)
}
}
return nil
}

View file

@ -31,6 +31,9 @@ EXPECT=$(cat << "EOF"
"boost-of-ids-mem-ratio": 3, "boost-of-ids-mem-ratio": 3,
"emoji-category-mem-ratio": 0.1, "emoji-category-mem-ratio": 0.1,
"emoji-mem-ratio": 3, "emoji-mem-ratio": 3,
"filter-keyword-mem-ratio": 0.5,
"filter-mem-ratio": 0.5,
"filter-status-mem-ratio": 0.5,
"follow-ids-mem-ratio": 4, "follow-ids-mem-ratio": 4,
"follow-mem-ratio": 2, "follow-mem-ratio": 2,
"follow-request-ids-mem-ratio": 2, "follow-request-ids-mem-ratio": 2,

View file

@ -37,6 +37,9 @@ var testModels = []interface{}{
&gtsmodel.Block{}, &gtsmodel.Block{},
&gtsmodel.DomainBlock{}, &gtsmodel.DomainBlock{},
&gtsmodel.EmailDomainBlock{}, &gtsmodel.EmailDomainBlock{},
&gtsmodel.Filter{},
&gtsmodel.FilterKeyword{},
&gtsmodel.FilterStatus{},
&gtsmodel.Follow{}, &gtsmodel.Follow{},
&gtsmodel.FollowRequest{}, &gtsmodel.FollowRequest{},
&gtsmodel.List{}, &gtsmodel.List{},
@ -329,6 +332,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
} }
} }
for _, v := range NewTestFilters() {
if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err)
}
}
for _, v := range NewTestFilterKeywords() {
if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err)
}
}
for _, v := range NewTestFilterStatuses() {
if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err)
}
}
if err := db.CreateInstanceAccount(ctx); err != nil { if err := db.CreateInstanceAccount(ctx); err != nil {
log.Panic(nil, err) log.Panic(nil, err)
} }

View file

@ -3263,6 +3263,87 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
} }
} }
func NewTestFilters() map[string]*gtsmodel.Filter {
return map[string]*gtsmodel.Filter{
"local_account_1_filter_1": {
ID: "01HN26VM6KZTW1ANNRVSBMA461",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
Title: "fnord",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
},
"local_account_1_filter_2": {
ID: "01HN277FSPQAWXZXK92QPPYF79",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
Title: "metasyntactic variables",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
},
"local_account_2_filter_1": {
ID: "01HNGFYJBED9FS0VWRVMY4TKXH",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1VYJAE00TVVGMM5JNJ8X",
Title: "gamer words",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
},
}
}
func NewTestFilterKeywords() map[string]*gtsmodel.FilterKeyword {
return map[string]*gtsmodel.FilterKeyword{
"local_account_1_filter_1_keyword_1": {
ID: "01HN272TAVWAXX72ZX4M8JZ0PS",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
FilterID: "01HN26VM6KZTW1ANNRVSBMA461",
Keyword: "fnord",
WholeWord: util.Ptr(true),
},
"local_account_1_filter_2_keyword_1": {
ID: "01HN277Y11ENG4EC1ERMAC9FH4",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
FilterID: "01HN277FSPQAWXZXK92QPPYF79",
Keyword: "foo",
WholeWord: util.Ptr(true),
},
"local_account_1_filter_2_keyword_2": {
ID: "01HN278494N88BA2FY4DZ5JTNS",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
FilterID: "01HN277FSPQAWXZXK92QPPYF79",
Keyword: "bar",
WholeWord: util.Ptr(true),
},
"local_account_2_filter_1_keyword_1": {
ID: "01HNGG51HV2JT67XQ5MQ7RA1WE",
CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"),
AccountID: "01F8MH1VYJAE00TVVGMM5JNJ8X",
FilterID: "01HNGFYJBED9FS0VWRVMY4TKXH",
Keyword: "Virtual Boy",
WholeWord: util.Ptr(true),
},
}
}
func NewTestFilterStatuses() map[string]*gtsmodel.FilterStatus {
// FUTURE: (filters v2) test filter statuses
return map[string]*gtsmodel.FilterStatus{}
}
// GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values. // GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values.
func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) { func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// convert the activity into json bytes // convert the activity into json bytes