diff --git a/internal/api/s2s/user/statusget.go b/internal/api/s2s/user/statusget.go index d16026d4..3dee0c88 100644 --- a/internal/api/s2s/user/statusget.go +++ b/internal/api/s2s/user/statusget.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" @@ -35,13 +36,15 @@ func (m *Module) StatusGETHandler(c *gin.Context) { "url": c.Request.RequestURI, }) - requestedUsername := c.Param(UsernameKey) + // usernames on our instance are always lowercase + requestedUsername := strings.ToLower(c.Param(UsernameKey)) if requestedUsername == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "no username specified in request"}) return } - requestedStatusID := c.Param(StatusIDKey) + // status IDs on our instance are always uppercase + requestedStatusID := strings.ToUpper(c.Param(StatusIDKey)) if requestedStatusID == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "no status id specified in request"}) return diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 876fb518..59292055 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/spf13/viper" @@ -199,7 +200,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri account := new(gtsmodel.Account) q := a.newAccountQ(account). - Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username). // ignore casing + Where("username = ?", strings.ToLower(username)). // usernames on our instance will always be lowercase WhereGroup(" AND ", whereEmptyOrNull("domain")) if err := q.Scan(ctx); err != nil { diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go index 3b5a3ac9..baa0baea 100644 --- a/internal/db/bundb/conn.go +++ b/internal/db/bundb/conn.go @@ -68,13 +68,12 @@ func (conn *DBConn) ProcessError(err error) db.Error { // Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { - // Get the select query result - count, err := query.Count(ctx) + exists, err := query.Exists(ctx) // Process error as our own and check if it exists switch err := conn.ProcessError(err); err { case nil: - return (count != 0), nil + return exists, nil case db.ErrNoEntries: return false, nil default: diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 417b2bec..e63a584b 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -21,6 +21,7 @@ package bundb import ( "context" "net/url" + "strings" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -39,7 +40,8 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db q := d.conn. NewSelect(). Model(>smodel.DomainBlock{}). - Where("LOWER(domain) = LOWER(?)", domain). + ExcludeColumn("id", "created_at", "updated_at", "created_by_account_id", "private_comment", "public_comment", "obfuscate", "subscription_id"). + Where("domain = ?", domain). Limit(1) return d.conn.Exists(ctx, q) @@ -50,7 +52,7 @@ func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (boo uniqueDomains := util.UniqueStrings(domains) for _, domain := range uniqueDomains { - if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { + if blocked, err := d.IsDomainBlocked(ctx, strings.ToLower(domain)); err != nil { return false, err } else if blocked { return blocked, nil diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go new file mode 100644 index 00000000..1a3fed24 --- /dev/null +++ b/internal/db/bundb/domain_test.go @@ -0,0 +1,57 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type DomainTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *DomainTestSuite) TestIsDomainBlocked() { + ctx := context.Background() + + domainBlock := >smodel.DomainBlock{ + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "some.bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + } + + // no domain block exists for the given domain yet + blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + suite.NoError(err) + suite.False(blocked) + + suite.db.Put(ctx, domainBlock) + + // domain block now exists + blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + suite.NoError(err) + suite.True(blocked) +} + +func TestDomainTestSuite(t *testing.T) { + suite.Run(t, new(DomainTestSuite)) +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 36955320..e2e2c96b 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -52,14 +52,25 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account q := r.conn. NewSelect(). Model(>smodel.Block{}). - Where("account_id = ?", account1). - Where("target_account_id = ?", account2). + ExcludeColumn("id", "created_at", "updated_at", "uri"). Limit(1) if eitherDirection { q = q. - WhereOr("target_account_id = ?", account1). - Where("account_id = ?", account2) + WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { + return inner. + Where("account_id = ?", account1). + Where("target_account_id = ?", account2) + }). + WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { + return inner. + Where("account_id = ?", account2). + Where("target_account_id = ?", account1) + }) + } else { + q = q. + Where("account_id = ?", account1). + Where("target_account_id = ?", account2) } return r.conn.Exists(ctx, q) diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index bb0f0e3d..34fe85a5 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) type RelationshipTestSuite struct { @@ -32,7 +33,45 @@ type RelationshipTestSuite struct { } func (suite *RelationshipTestSuite) TestIsBlocked() { - suite.Suite.T().Skip("TODO: implement") + ctx := context.Background() + + account1 := suite.testAccounts["local_account_1"].ID + account2 := suite.testAccounts["local_account_2"].ID + + // no blocks exist between account 1 and account 2 + blocked, err := suite.db.IsBlocked(ctx, account1, account2, false) + suite.NoError(err) + suite.False(blocked) + + blocked, err = suite.db.IsBlocked(ctx, account2, account1, false) + suite.NoError(err) + suite.False(blocked) + + // have account1 block account2 + suite.db.Put(ctx, >smodel.Block{ + ID: "01G202BCSXXJZ70BHB5KCAHH8C", + URI: "http://localhost:8080/some_block_uri_1", + AccountID: account1, + TargetAccountID: account2, + }) + + // account 1 now blocks account 2 + blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) + suite.NoError(err) + suite.True(blocked) + + // account 2 doesn't block account 1 + blocked, err = suite.db.IsBlocked(ctx, account2, account1, false) + suite.NoError(err) + suite.False(blocked) + + // a block exists in either direction between the two + blocked, err = suite.db.IsBlocked(ctx, account1, account2, true) + suite.NoError(err) + suite.True(blocked) + blocked, err = suite.db.IsBlocked(ctx, account2, account1, true) + suite.NoError(err) + suite.True(blocked) } func (suite *RelationshipTestSuite) TestGetBlock() { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 1783723b..4e670f59 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -70,7 +70,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat return s.cache.GetByID(id) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("LOWER(status.id) = LOWER(?)", id).Scan(ctx) + return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) }, ) } @@ -82,7 +82,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St return s.cache.GetByURI(uri) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx) + return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx) }, ) } @@ -94,7 +94,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St return s.cache.GetByURL(url) }, func(status *gtsmodel.Status) error { - return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx) + return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx) }, ) } diff --git a/internal/db/bundb/trace.go b/internal/db/bundb/trace.go index 93c23178..27b5e22a 100644 --- a/internal/db/bundb/trace.go +++ b/internal/db/bundb/trace.go @@ -47,6 +47,11 @@ func (q *debugQueryHook) AfterQuery(_ context.Context, event *bun.QueryEvent) { "operation": event.Operation(), }) + if dur > 1*time.Second { + l.Warnf("SLOW DATABASE QUERY [%s] %s", dur, event.Query) + return + } + if logrus.GetLevel() == logrus.TraceLevel { l.Tracef("[%s] %s", dur, event.Query) } else { diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 7bcefc14..e1ca3e7e 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -134,7 +134,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr // authentication has passed, so add an instance entry for this instance if it hasn't been done already i := >smodel.Instance{} - if err := f.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host}}, i); err != nil { if err != db.ErrNoEntries { // there's been an actual error return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err) diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go index 9bf7c2fd..3cfaabce 100644 --- a/internal/processing/admin/createdomainblock.go +++ b/internal/processing/admin/createdomainblock.go @@ -21,6 +21,7 @@ package admin import ( "context" "fmt" + "strings" "time" "github.com/sirupsen/logrus" @@ -35,9 +36,12 @@ import ( ) func (p *processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { + // domain blocks will always be lowercase + domain = strings.ToLower(domain) + // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work domainBlock := >smodel.DomainBlock{} - err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) + err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain}}, domainBlock) if err != nil { if err != db.ErrNoEntries { // something went wrong in the DB @@ -95,7 +99,7 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account // if we have an instance entry for this domain, update it with the new block ID and clear all fields instance := >smodel.Instance{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { instance.Title = "" instance.UpdatedAt = time.Now() instance.SuspendedAt = time.Now() diff --git a/internal/processing/federation/getstatus.go b/internal/processing/federation/getstatus.go index 820f1a19..2cc37071 100644 --- a/internal/processing/federation/getstatus.go +++ b/internal/processing/federation/getstatus.go @@ -24,9 +24,7 @@ import ( "net/url" "github.com/superseriousbusiness/activity/streams" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) func (p *processor) GetStatus(ctx context.Context, requestedUsername string, requestedStatusID string, requestURL *url.URL) (interface{}, gtserror.WithCode) { @@ -59,14 +57,15 @@ func (p *processor) GetStatus(ctx context.Context, requestedUsername string, req } // get the status out of the database here - s := >smodel.Status{} - if err := p.db.GetWhere(ctx, []db.Where{ - {Key: "id", Value: requestedStatusID, CaseInsensitive: true}, - {Key: "account_id", Value: requestedAccount.ID, CaseInsensitive: true}, - }, s); err != nil { + s, err := p.db.GetStatusByID(ctx, requestedStatusID) + if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting status with id %s and account id %s: %s", requestedStatusID, requestedAccount.ID, err)) } + if s.AccountID != requestedAccount.ID { + return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", s.ID, requestedAccount.ID)) + } + visible, err := p.filter.StatusVisible(ctx, s, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/web/thread.go b/internal/web/thread.go index 4a448690..0450f6b2 100644 --- a/internal/web/thread.go +++ b/internal/web/thread.go @@ -36,14 +36,16 @@ func (m *Module) threadTemplateHandler(c *gin.Context) { ctx := c.Request.Context() - username := c.Param(usernameKey) + // usernames on our instance will always be lowercase + username := strings.ToLower(c.Param(usernameKey)) if username == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "no account username specified"}) return } - statusID := c.Param(statusIDKey) - if username == "" { + // status ids will always be uppercase + statusID := strings.ToUpper(c.Param(statusIDKey)) + if statusID == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "no status id specified"}) return }