Database updates (#144)

* start moving some database stuff around

* continue moving db stuff around

* more fiddling

* more updates

* and some more

* and yet more

* i broke SOMETHING but what, it's a mystery

* tidy up

* vendor ttlcache

* use ttlcache

* fix up some tests

* rename some stuff

* little reminder

* some more updates
This commit is contained in:
tobi 2021-08-20 12:26:56 +02:00 committed by GitHub
parent ce190d867c
commit 4920229a3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
164 changed files with 4850 additions and 2617 deletions

View file

@ -137,6 +137,7 @@ The following libraries and frameworks are used by GoToSocial, with gratitude
* [mvdan/xurls](https://github.com/mvdan/xurls); URL parsing regular expressions. [BSD-3-Clause License](https://spdx.org/licenses/BSD-3-Clause.html). * [mvdan/xurls](https://github.com/mvdan/xurls); URL parsing regular expressions. [BSD-3-Clause License](https://spdx.org/licenses/BSD-3-Clause.html).
* [nfnt/resize](https://github.com/nfnt/resize); convenient image resizing. [ISC License](https://spdx.org/licenses/ISC.html). * [nfnt/resize](https://github.com/nfnt/resize); convenient image resizing. [ISC License](https://spdx.org/licenses/ISC.html).
* [oklog/ulid](https://github.com/oklog/ulid); sequential, database-friendly ID generation. [Apache-2.0 License](https://spdx.org/licenses/Apache-2.0.html). * [oklog/ulid](https://github.com/oklog/ulid); sequential, database-friendly ID generation. [Apache-2.0 License](https://spdx.org/licenses/Apache-2.0.html).
* [ReneKroon/ttlcache](https://github.com/ReneKroon/ttlcache); in-memory caching. [MIT License](https://spdx.org/licenses/MIT.html).
* [russross/blackfriday](https://github.com/russross/blackfriday); markdown parsing for statuses. [Simplified BSD License](https://spdx.org/licenses/BSD-2-Clause.html). * [russross/blackfriday](https://github.com/russross/blackfriday); markdown parsing for statuses. [Simplified BSD License](https://spdx.org/licenses/BSD-2-Clause.html).
* [sirupsen/logrus](https://github.com/sirupsen/logrus); logging. [MIT License](https://spdx.org/licenses/MIT.html). * [sirupsen/logrus](https://github.com/sirupsen/logrus); logging. [MIT License](https://spdx.org/licenses/MIT.html).
* [stretchr/testify](https://github.com/stretchr/testify); test framework. [MIT License](https://spdx.org/licenses/MIT.html). * [stretchr/testify](https://github.com/stretchr/testify); test framework. [MIT License](https://spdx.org/licenses/MIT.html).

View file

@ -1679,7 +1679,7 @@ info:
name: AGPL3 name: AGPL3
url: https://www.gnu.org/licenses/agpl-3.0.en.html url: https://www.gnu.org/licenses/agpl-3.0.en.html
title: GoToSocial title: GoToSocial
version: 0.1.0-SNAPSHOT-dereference_remote_replies version: 0.1.0-SNAPSHOT
paths: paths:
/api/v1/accounts: /api/v1/accounts:
post: post:
@ -3404,6 +3404,8 @@ paths:
description: "" description: ""
schema: schema:
$ref: '#/definitions/swaggerStatusRepliesCollection' $ref: '#/definitions/swaggerStatusRepliesCollection'
"400":
description: bad request
"401": "401":
description: unauthorized description: unauthorized
"403": "403":

1
go.mod
View file

@ -3,6 +3,7 @@ module github.com/superseriousbusiness/gotosocial
go 1.16 go 1.16
require ( require (
github.com/ReneKroon/ttlcache v1.7.0
github.com/buckket/go-blurhash v1.1.0 github.com/buckket/go-blurhash v1.1.0
github.com/coreos/go-oidc/v3 v3.0.0 github.com/coreos/go-oidc/v3 v3.0.0
github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect

4
go.sum
View file

@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ReneKroon/ttlcache v1.7.0 h1:8BkjFfrzVFXyrqnMtezAaJ6AHPSsVV10m6w28N/Fgkk=
github.com/ReneKroon/ttlcache v1.7.0/go.mod h1:8BGGzdumrIjWxdRx8zpK6L3oGMWvIXdvB2GD1cfvd+I=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4= github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
@ -425,6 +427,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opentelemetry.io/otel v0.13.0/go.mod h1:dlSNewoRYikTkotEnxdmuBHgzT+k/idJSfDv/FxEnOY= go.opentelemetry.io/otel v0.13.0/go.mod h1:dlSNewoRYikTkotEnxdmuBHgzT+k/idJSfDv/FxEnOY=
go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4=
go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI=
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View file

@ -581,7 +581,7 @@ func ExtractMention(i Mentionable) (*gtsmodel.Mention, error) {
if hrefProp == nil || !hrefProp.IsIRI() { if hrefProp == nil || !hrefProp.IsIRI() {
return nil, errors.New("no href prop") return nil, errors.New("no href prop")
} }
mention.MentionedAccountURI = hrefProp.GetIRI().String() mention.TargetAccountURI = hrefProp.GetIRI().String()
return mention, nil return mention, nil
} }

View file

@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return user, nil return user, nil
} }
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// we have an actual error in the database // we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err) return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
} }
@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
} }
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// we have an actual error in the database // we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err) return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
} }

View file

@ -6,8 +6,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-fed/httpsig" "github.com/go-fed/httpsig"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
@ -33,13 +31,13 @@ func (m *Module) SignatureCheck(c *gin.Context) {
// we managed to parse the url! // we managed to parse the url!
// if the domain is blocked we want to bail as early as possible // if the domain is blocked we want to bail as early as possible
blockedDomain, err := m.blockedDomain(requestingPublicKeyID.Host) blocked, err := m.db.IsURIBlocked(requestingPublicKeyID)
if err != nil { if err != nil {
l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err) l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err)
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
if blockedDomain { if blocked {
l.Infof("domain %s is blocked", requestingPublicKeyID.Host) l.Infof("domain %s is blocked", requestingPublicKeyID.Host)
c.AbortWithStatus(http.StatusForbidden) c.AbortWithStatus(http.StatusForbidden)
return return
@ -50,20 +48,3 @@ func (m *Module) SignatureCheck(c *gin.Context) {
} }
} }
} }
func (m *Module) blockedDomain(host string) (bool, error) {
b := &gtsmodel.DomainBlock{}
err := m.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
if err == nil {
// block exists
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
// there are no entries so there's no block
return false, nil
}
// there's an actual error
return false, err
}

View file

@ -18,8 +18,28 @@
package cache package cache
import (
"time"
"github.com/ReneKroon/ttlcache"
)
// Cache defines an in-memory cache that is safe to be wiped when the application is restarted // Cache defines an in-memory cache that is safe to be wiped when the application is restarted
type Cache interface { type Cache interface {
Store(k string, v interface{}) error Store(k string, v interface{}) error
Fetch(k string) (interface{}, error) Fetch(k string) (interface{}, error)
} }
type cache struct {
c *ttlcache.Cache
}
// New returns a new in-memory cache.
func New() Cache {
c := ttlcache.NewCache()
c.SetTTL(30 * time.Second)
cache := &cache{
c: c,
}
return cache
}

27
internal/cache/error.go vendored Normal file
View file

@ -0,0 +1,27 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import "errors"
// Error models an error returned by the in-memory cache.
type Error error
// ErrNotFound means that a value for the requested key was not found in the cache.
var ErrNotFound = errors.New("value not found in cache")

View file

@ -16,18 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package pg package cache
import ( func (c *cache) Fetch(k string) (interface{}, error) {
"strings" i, stored := c.c.Get(k)
if !stored {
"github.com/superseriousbusiness/gotosocial/internal/db" return nil, ErrNotFound
)
func (ps *postgresService) Put(i interface{}) error {
_, err := ps.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists{}
} }
return err
return i, nil
} }

24
internal/cache/store.go vendored Normal file
View file

@ -0,0 +1,24 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
func (c *cache) Store(k string, v interface{}) error {
c.c.Set(k, v)
return nil
}

View file

@ -88,8 +88,8 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err return err
} }
a := &gtsmodel.Account{} a, err := dbConn.GetLocalAccountByUsername(username)
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil { if err != nil {
return err return err
} }
@ -123,8 +123,8 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err return err
} }
a := &gtsmodel.Account{} a, err := dbConn.GetLocalAccountByUsername(username)
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil { if err != nil {
return err return err
} }
@ -155,8 +155,8 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
return err return err
} }
a := &gtsmodel.Account{} a, err := dbConn.GetLocalAccountByUsername(username)
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil { if err != nil {
return err return err
} }
@ -187,8 +187,8 @@ var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err return err
} }
a := &gtsmodel.Account{} a, err := dbConn.GetLocalAccountByUsername(username)
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil { if err != nil {
return err return err
} }
@ -233,8 +233,8 @@ var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config,
return err return err
} }
a := &gtsmodel.Account{} a, err := dbConn.GetLocalAccountByUsername(username)
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil { if err != nil {
return err return err
} }

View file

@ -62,6 +62,8 @@ var models []interface{} = []interface{}{
&gtsmodel.MediaAttachment{}, &gtsmodel.MediaAttachment{},
&gtsmodel.Mention{}, &gtsmodel.Mention{},
&gtsmodel.Status{}, &gtsmodel.Status{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
&gtsmodel.StatusFave{}, &gtsmodel.StatusFave{},
&gtsmodel.StatusBookmark{}, &gtsmodel.StatusBookmark{},
&gtsmodel.StatusMute{}, &gtsmodel.StatusMute{},

66
internal/db/account.go Normal file
View file

@ -0,0 +1,66 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(uri string) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
}

53
internal/db/admin.go Normal file
View file

@ -0,0 +1,53 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Admin contains functions related to instance administration (new signups etc).
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) Error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) Error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() Error
}

87
internal/db/basic.go Normal file
View file

@ -0,0 +1,87 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "context"
// Basic wraps basic database functionality.
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) Error
// RegisterTable registers a table for use in many2many relations.
// For implementations that don't use tables, or many2many relations, this can just return nil.
RegisterTable(i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) Error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) Error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) Error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) Error
}

View file

@ -19,9 +19,6 @@
package db package db
import ( import (
"context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
@ -30,257 +27,19 @@ const (
DBTypePostgres string = "POSTGRES" DBTypePostgres string = "POSTGRES"
) )
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). // DB provides methods for interacting with an underlying database or other storage mechanism.
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
// by whatever is returned from the database.
type DB interface { type DB interface {
/* Account
BASIC DB FUNCTIONALITY Admin
*/ Basic
Domain
// CreateTable creates a table for the given interface. Instance
// For implementations that don't use tables, this can just return nil. Media
CreateTable(i interface{}) error Mention
Notification
// DropTable drops the table for the given interface. Relationship
// For implementations that don't use tables, this can just return nil. Status
DropTable(i interface{}) error Timeline
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) error
/*
HANDY SHORTCUTS
*/
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() error
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountByUserID(userID string, account *gtsmodel.Account) error
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
// according to its username, which should be unique.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
// GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
//
// If localOnly is set to true, then only followers from *this instance* will be returned.
GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
// GetFavesByAccountID is a shortcut for the common action of fetching a list of faves made by the given accountID.
// The given slice 'faves' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error
// CountStatusesByAccountID is a shortcut for the common action of counting statuses produced by accountID.
CountStatusesByAccountID(accountID string) (int, error)
// GetStatusesForAccount is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
// GetLastStatusForAccountID simply gets the most recent status by the given account.
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error
// GetHeaderForAccountID gets the current header for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error
// Blocked checks whether a block exists in eiher direction between two accounts.
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
Blocked(account1 string, account2 string) (bool, error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
// GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
// GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
// GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
// StatusParents get the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
// StatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
// StatusFavedBy checks if a given status has been faved by a given account ID
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusMutedBy checks if a given status has been muted by a given account ID
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
// WhoFavedStatus returns a slice of accounts who faved the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
// WhoBoostedStatus returns a slice of accounts who boosted the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
GetUserCountForInstance(domain string) (int, error)
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
GetStatusCountForInstance(domain string) (int, error)
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
GetDomainCountForInstance(domain string) (int, error)
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
/* /*
USEFUL CONVERSION FUNCTIONS USEFUL CONVERSION FUNCTIONS

36
internal/db/domain.go Normal file
View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "net/url"
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(uris []*url.URL) (bool, Error)
}

View file

@ -18,16 +18,18 @@
package db package db
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query. import "fmt"
type ErrNoEntries struct{}
func (e ErrNoEntries) Error() string { // Error denotes a database error.
return "no entries" type Error error
}
// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints. var (
type ErrAlreadyExists struct{} // ErrNoEntries is returned when a caller expected an entry for a query, but none was found.
ErrNoEntries Error = fmt.Errorf("no entries")
func (e ErrAlreadyExists) Error() string { // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
return "already exists" ErrMultipleEntries Error = fmt.Errorf("multiple entries")
} // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
ErrAlreadyExists Error = fmt.Errorf("already exists")
// ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error")
)

36
internal/db/instance.go Normal file
View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
CountInstanceUsers(domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
CountInstanceStatuses(domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
CountInstanceDomains(domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}

27
internal/db/media.go Normal file
View file

@ -0,0 +1,27 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
}

30
internal/db/mention.go Normal file
View file

@ -0,0 +1,30 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
GetMention(id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
}

View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
GetNotification(id string) (*gtsmodel.Notification, Error)
}

256
internal/db/pg/account.go Normal file
View file

@ -0,0 +1,256 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type accountDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
return a.conn.Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account)
if domain == "" {
q = q.
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", pg.Ident("domain"))
}
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
status := &gtsmodel.Status{}
q := a.conn.Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Select())
return status.CreatedAt, err
}
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
var headerOrAVI string
if mediaAttachment.Avatar {
headerOrAVI = "avatar"
} else if mediaAttachment.Header {
headerOrAVI = "header"
} else {
return errors.New("given media attachment was neither a header nor an avatar")
}
// TODO: there are probably more side effects here that need to be handled
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("username = ?", username).
Where("? IS NULL", pg.Ident("domain"))
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
if err := a.conn.Model(&faves).
Where("account_id = ?", accountID).
Select(); err != nil {
if err == pg.ErrNoRows {
return faves, nil
}
return nil, err
}
return faves, nil
}
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if limit != 0 {
q = q.Limit(limit)
}
if excludeReplies {
q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if mediaOnly {
q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
})
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

View file

@ -0,0 +1,70 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
PGStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *AccountTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *AccountTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(account)
suite.NotNil(account.AvatarMediaAttachment)
suite.NotEmpty(account.AvatarMediaAttachment.URL)
suite.NotNil(account.HeaderMediaAttachment)
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

235
internal/db/pg/admin.go Normal file
View file

@ -0,0 +1,235 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"net"
"net/mail"
"strings"
"time"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.Error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (a *adminDB) IsEmailAvailable(email string) db.Error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
return nil, fmt.Errorf("db error checking existence of account: %s", err)
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
acct = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
Reason: reason,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = a.conn.Model(acct).Insert(); err != nil {
return nil, err
}
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
newUserID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
u := &gtsmodel.User{
ID: newUserID,
AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if emailVerified {
u.ConfirmedAt = time.Now()
u.Email = email
}
if admin {
u.Admin = true
u.Moderator = true
}
if _, err = a.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (a *adminDB) CreateInstanceAccount() db.Error {
username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return err
}
aID, err := id.NewRandomULID()
if err != nil {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
acct := &gtsmodel.Account{
ID: aID,
Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
return nil
}
func (a *adminDB) CreateInstanceInstance() db.Error {
iID, err := id.NewRandomULID()
if err != nil {
return err
}
i := &gtsmodel.Instance{
ID: iID,
Domain: a.config.Host,
Title: a.config.Host,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
return nil
}

205
internal/db/pg/basic.go Normal file
View file

@ -0,0 +1,205 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
type basicDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.Error {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
func (b *basicDB) GetByID(id string, i interface{}) db.Error {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetAll(i interface{}) db.Error {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
if _, err := q.Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}
func (b *basicDB) CreateTable(i interface{}) db.Error {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.Error {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) RegisterTable(i interface{}) db.Error {
orm.RegisterTable(i)
return nil
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil
}

View file

@ -1,67 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
fq := ps.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries{}
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

83
internal/db/pg/domain.go Normal file
View file

@ -0,0 +1,83 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"net/url"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type domainDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
blocked, err := d.conn.
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Exists()
err = processErrorResponse(err)
return blocked, err
}
func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
if blocked, err := d.IsDomainBlocked(domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
}
}
// no blocks found
return false, nil
}
func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
return d.IsDomainBlocked(domain)
}
func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
return d.AreDomainsBlocked(domains)
}

View file

@ -1,75 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"errors"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) GetByID(id string, i interface{}) error {
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
if err := ps.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}

View file

@ -19,15 +19,26 @@
package pg package pg
import ( import (
"context"
"github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { type instanceDB struct {
q := ps.conn.Model(&[]*gtsmodel.Account{}) config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
if domain == ps.config.Host { func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null // if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain")) q = q.Where("? IS NULL", pg.Ident("domain"))
} else { } else {
@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count() return q.Count()
} }
func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
q := ps.conn.Model(&[]*gtsmodel.Status{}) q := i.conn.Model(&[]*gtsmodel.Status{})
if domain == ps.config.Host { if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true // if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true) q = q.Where("local = ?", true)
} else { } else {
@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count() return q.Count()
} }
func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
q := ps.conn.Model(&[]*gtsmodel.Instance{}) q := i.conn.Model(&[]*gtsmodel.Instance{})
if domain == ps.config.Host { if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about // if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked // exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count() return q.Count()
} }
func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
ps.log.Debug("GetAccountsForInstance") i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{} accounts := []*gtsmodel.Account{}
q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" { if maxID != "" {
q = q.Where("id < ?", maxID) q = q.Where("id < ?", maxID)
@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select() err := q.Select()
if err != nil { if err != nil {
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return nil, err return nil, err
} }
if len(accounts) == 0 { if len(accounts) == 0 {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return accounts, nil return accounts, nil

View file

@ -19,39 +19,35 @@
package pg package pg
import ( import (
"errors" "context"
"github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
func (ps *postgresService) DeleteByID(id string, i interface{}) error { type mediaDB struct {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { config *config.Config
// if there are no rows *anyway* then that's fine conn *pg.DB
// just return err if there's an actual error log *logrus.Logger
if err != pg.ErrNoRows { cancel context.CancelFunc
return err
}
}
return nil
} }
func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error { func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
if len(where) == 0 { return m.conn.Model(i).
return errors.New("no queries provided") Relation("Account")
} }
q := ps.conn.Model(i) func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
for _, w := range where { attachment := &gtsmodel.MediaAttachment{}
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
} q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
if _, err := q.Delete(); err != nil {
// if there are no rows *anyway* then that's fine err := processErrorResponse(q.Select())
// just return err if there's an actual error
if err != pg.ErrNoRows { return attachment, err
return err
}
}
return nil
} }

108
internal/db/pg/mention.go Normal file
View file

@ -0,0 +1,108 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type mentionDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
}
func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
return m.conn.Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && mention != nil {
m.cacheMention(id, mention)
}
return mention, err
}
func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
mention, err := m.GetMention(i)
if err != nil {
return nil, processErrorResponse(err)
}
mentions = append(mentions, mention)
}
return mentions, nil
}

View file

@ -0,0 +1,135 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type notificationDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
if n.cache == nil {
n.cache = cache.New()
}
if err := n.cache.Store(id, notification); err != nil {
n.log.Panicf("notificationDB: error storing in cache: %s", err)
}
}
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
if n.cache == nil {
n.cache = cache.New()
return nil, false
}
nI, err := n.cache.Fetch(id)
if err != nil || nI == nil {
return nil, false
}
notification, ok := nI.(*gtsmodel.Notification)
if !ok {
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
}
return notification, true
}
func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
return n.conn.Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
notification := &gtsmodel.Notification{}
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && notification != nil {
n.cacheNotification(id, notification)
}
return notification, err
}
func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
err := processErrorResponse(q.Select())
if err != nil {
return nil, err
}
// now we have the IDs, select the notifs one by one
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
}
notifications = append(notifications, notif)
}
return notifications, nil
}

View file

@ -20,15 +20,11 @@ package pg
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"net"
"net/mail"
"os" "os"
"strings" "strings"
"time" "time"
@ -41,12 +37,26 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
) )
var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
// postgresService satisfies the DB interface // postgresService satisfies the DB interface
type postgresService struct { type postgresService struct {
db.Account
db.Admin
db.Basic
db.Domain
db.Instance
db.Media
db.Mention
db.Notification
db.Relationship
db.Status
db.Timeline
config *config.Config config *config.Config
conn *pg.DB conn *pg.DB
log *logrus.Logger log *logrus.Logger
@ -56,6 +66,11 @@ type postgresService struct {
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
for _, t := range registerTables {
// https://pg.uptrace.dev/orm/many-to-many-relation/
orm.RegisterTable(t)
}
opts, err := derivePGOptions(c) opts, err := derivePGOptions(c)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err) return nil, fmt.Errorf("could not create postgres service: %s", err)
@ -91,6 +106,72 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version) log.Infof("connected to postgres version: %s", version)
ps := &postgresService{ ps := &postgresService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
config: c, config: c,
conn: conn, conn: conn,
log: log, log: log,
@ -199,724 +280,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return options, nil return options, nil
} }
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
ps.cancel()
return err
}
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
}
ps.log.Info("creating db schema")
for _, model := range models {
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
if err != nil {
return err
}
}
ps.log.Info("db schema created")
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries{}
}
return nil, err
}
// create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
ID: fr.ID,
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}
return follow, nil
}
func (ps *postgresService) CreateInstanceAccount() error {
username := ps.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return err
}
aID, err := id.NewRandomULID()
if err != nil {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
a := &gtsmodel.Account{
ID: aID,
Username: ps.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance account %s with id %s", username, a.ID)
} else {
ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
}
return nil
}
func (ps *postgresService) CreateInstanceInstance() error {
iID, err := id.NewRandomULID()
if err != nil {
return err
}
i := &gtsmodel.Instance{
ID: iID,
Domain: ps.config.Host,
Title: ps.config.Host,
URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
}
inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
} else {
ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
}
return nil
}
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
user := &gtsmodel.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
q := ps.conn.Model(followers)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *pg.Query) (*pg.Query, error) {
q = q.
WhereOr("? IS NULL", pg.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q, nil
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error {
if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) CountStatusesByAccountID(accountID string) (int, error) {
count, err := ps.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
if err != nil {
if err == pg.ErrNoRows {
return 0, nil
}
return 0, err
}
return count, nil
}
func (ps *postgresService) GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
ps.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if limit != 0 {
q = q.Limit(limit)
}
if excludeReplies {
q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if mediaOnly {
q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
})
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
}
ps.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) IsUsernameAvailable(username string) error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := ps.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
a := &gtsmodel.Account{}
err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
return nil, fmt.Errorf("db error checking existence of account: %s", err)
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
a = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
Reason: reason,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
newUserID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
u := &gtsmodel.User{
ID: newUserID,
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if emailVerified {
u.ConfirmedAt = time.Now()
u.Email = email
}
if admin {
u.Admin = true
u.Moderator = true
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
var headerOrAVI string
if mediaAttachment.Avatar {
headerOrAVI = "avatar"
} else if mediaAttachment.Header {
headerOrAVI = "header"
} else {
return errors.New("given media attachment was neither a header nor an avatar")
}
// TODO: there are probably more side effects here that need to be handled
if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := ps.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (ps *postgresService) GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.HeaderMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.AvatarMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
// TODO: check domain blocks as well
var blocked bool
if err := ps.conn.Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
Select(); err != nil {
if err == pg.ErrNoRows {
blocked = false
return blocked, nil
}
return blocked, err
}
blocked = true
return blocked, nil
}
func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
r := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
r.Following = false
r.ShowingReblogs = false
r.Notifying = false
} else {
// follow exists so we can fill these fields out...
r.Following = true
r.ShowingReblogs = follow.ShowReblogs
r.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
r.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
r.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.Requested = requested
return r, nil
}
func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
}
return false, err
}
// make sure account 2 follows account 1
f2, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
}
return false, err
}
return f1 && f2, nil
}
func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
accounts := []*gtsmodel.Account{}
faves := []*gtsmodel.StatusFave{}
if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has faved this status, so that's fine
}
return nil, err // an actual error has occurred
}
for _, f := range faves {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
return nil, err // an actual error has occurred
}
accounts = append(accounts, acc)
}
return accounts, nil
}
func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
accounts := []*gtsmodel.Account{}
boosts := []*gtsmodel.Status{}
if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
}
return nil, err // an actual error has occurred
}
for _, f := range boosts {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
return nil, err // an actual error has occurred
}
accounts = append(accounts, acc)
}
return accounts, nil
}
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
}
return notifications, nil
}
/* /*
CONVERSION FUNCTIONS CONVERSION FUNCTIONS
*/ */
@ -988,14 +351,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// id, createdAt and updatedAt will be populated by the db, so we have everything we need! // id, createdAt and updatedAt will be populated by the db, so we have everything we need!
menchies = append(menchies, &gtsmodel.Mention{ menchies = append(menchies, &gtsmodel.Mention{
StatusID: statusID, StatusID: statusID,
OriginAccountID: ogAccount.ID, OriginAccountID: ogAccount.ID,
OriginAccountURI: ogAccount.URI, OriginAccountURI: ogAccount.URI,
TargetAccountID: mentionedAccount.ID, TargetAccountID: mentionedAccount.ID,
NameString: a, NameString: a,
MentionedAccountURI: mentionedAccount.URI, TargetAccountURI: mentionedAccount.URI,
MentionedAccountURL: mentionedAccount.URL, TargetAccountURL: mentionedAccount.URL,
GTSAccount: mentionedAccount, OriginAccount: mentionedAccount,
}) })
} }
return menchies, nil return menchies, nil

47
internal/db/pg/pg_test.go Normal file
View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
db db.DB
log *logrus.Logger
// standard suite models
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
}

View file

@ -0,0 +1,276 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type relationshipDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
return r.conn.Model(block).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
return r.conn.Model(follow).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
Where("target_account_id = ?", account2)
if eitherDirection {
q = q.
WhereOr("target_account_id = ?", account1).
Where("account_id = ?", account2)
}
return q.Exists()
}
func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Select())
return block, err
}
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = follow.ShowReblogs
rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
rel.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.Requested = requested
return rel, nil
}
func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return q.Exists()
}
func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return q.Exists()
}
func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries
}
return nil, err
}
// create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
ID: fr.ID,
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Select())
return followRequests, err
}
func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Select())
return follows, err
}
func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
return r.conn.
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
Count()
}
func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *pg.Query) (*pg.Query, error) {
q = q.
WhereOr("? IS NULL", pg.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q, nil
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return follows, nil
}
return nil, err
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
return r.conn.
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
Count()
}

318
internal/db/pg/status.go Normal file
View file

@ -0,0 +1,318 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"container/list"
"context"
"errors"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type statusDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
return s.conn.Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
Relation("Emojis").
Relation("Account").
Relation("InReplyTo").
Relation("InReplyToAccount").
Relation("BoostOf").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
return s.conn.Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(id, status)
}
return status, err
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
transaction := func(tx *pg.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Insert(); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Insert(); err != nil {
return err
}
}
// change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.Model(a).
Where("id = ?", a.ID).
Update(); err != nil {
return err
}
}
_, err := tx.Model(status).Insert()
return err
}
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
}
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
s.statusParent(parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
s.statusChildren(child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Select())
return faves, err
}
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Select())
return reblogs, err
}

View file

@ -0,0 +1,134 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusTestSuite struct {
PGStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *StatusTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *StatusTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Tags)
suite.NotEmpty(status.Attachments)
suite.NotEmpty(status.Emojis)
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
suite.NotNil(status.InReplyTo)
suite.NotNil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Nanoseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
fmt.Println(duration2.Nanoseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(StatusTestSuite))
}

View file

@ -1,104 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"container/list"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
parents := []*gtsmodel.Status{}
ps.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus := &gtsmodel.Status{}
if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
ps.statusParent(parentStatus, foundStatuses, false)
}
func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
ps.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
ps.statusChildren(child, foundStatuses, false, minID)
}
}

View file

@ -19,16 +19,26 @@
package pg package pg
import ( import (
"context"
"sort" "sort"
"github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { type timelineDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{} statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses) q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*"). q = q.ColumnExpr("status.*").
// Find out who accountID follows. // Find out who accountID follows.
@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select() err := q.Select()
if err != nil { if err != nil {
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return nil, err return nil, err
} }
if len(statuses) == 0 { if len(statuses) == 0 {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return statuses, nil return statuses, nil
} }
func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{} statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses). q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic). Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")). Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")). Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select() err := q.Select()
if err != nil { if err != nil {
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return nil, err return nil, err
} }
if len(statuses) == 0 { if len(statuses) == 0 {
return nil, db.ErrNoEntries{} return nil, db.ErrNoEntries
} }
return statuses, nil return statuses, nil
@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) { func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{} faves := []*gtsmodel.StatusFave{}
fq := ps.conn.Model(&faves). fq := t.conn.Model(&faves).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
Order("id DESC") Order("id DESC")
@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select() err := fq.Select()
if err != nil { if err != nil {
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{} return nil, "", "", db.ErrNoEntries
} }
return nil, "", "", err return nil, "", "", err
} }
if len(faves) == 0 { if len(faves) == 0 {
return nil, "", "", db.ErrNoEntries{} return nil, "", "", db.ErrNoEntries
} }
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
} }
statuses := []*gtsmodel.Status{} statuses := []*gtsmodel.Status{}
err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil { if err != nil {
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{} return nil, "", "", db.ErrNoEntries
} }
return nil, "", "", err return nil, "", "", err
} }
if len(statuses) == 0 { if len(statuses) == 0 {
return nil, "", "", db.ErrNoEntries{} return nil, "", "", db.ErrNoEntries
} }
// arrange statuses by fave ID // arrange statuses by fave ID

View file

@ -1,73 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}

25
internal/db/pg/util.go Normal file
View file

@ -0,0 +1,25 @@
package pg
import (
"strings"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case pg.ErrNoRows:
return db.ErrNoEntries
case pg.ErrMultiRows:
return db.ErrMultipleEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}

View file

@ -0,0 +1,71 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
CountAccountFollows(accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
}

75
internal/db/status.go Normal file
View file

@ -0,0 +1,75 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
GetStatusByID(id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
GetStatusByURI(uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
GetStatusByURL(uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
PutStatus(status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}

44
internal/db/timeline.go Normal file
View file

@ -0,0 +1,44 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation package federation
import ( import (

View file

@ -29,7 +29,6 @@ import (
"github.com/go-fed/activity/streams/vocab" "github.com/go-fed/activity/streams/vocab"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
@ -65,8 +64,8 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
new := true new := true
// check if we already have the account in our db // check if we already have the account in our db
maybeAccount := &gtsmodel.Account{} maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String())
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: remoteAccountID.String()}}, maybeAccount); err == nil { if err == nil {
// we've seen this account before so it's not new // we've seen this account before so it's not new
new = false new = false
if !refresh { if !refresh {

View file

@ -27,14 +27,14 @@ import (
) )
func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error { func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error {
if announce.GTSBoostedStatus == nil || announce.GTSBoostedStatus.URI == "" { if announce.BoostOf == nil || announce.BoostOf.URI == "" {
// we can't do anything unfortunately // we can't do anything unfortunately
return errors.New("DereferenceAnnounce: no URI to dereference") return errors.New("DereferenceAnnounce: no URI to dereference")
} }
boostedStatusURI, err := url.Parse(announce.GTSBoostedStatus.URI) boostedStatusURI, err := url.Parse(announce.BoostOf.URI)
if err != nil { if err != nil {
return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.GTSBoostedStatus.URI, err) return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err)
} }
if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil { if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host) return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host)
@ -47,7 +47,7 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false) boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false)
if err != nil { if err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.GTSBoostedStatus.URI, err) return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
} }
announce.Content = boostedStatus.Content announce.Content = boostedStatus.Content
@ -60,6 +60,6 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
announce.BoostOfAccountID = boostedStatus.AccountID announce.BoostOfAccountID = boostedStatus.AccountID
announce.Visibility = boostedStatus.Visibility announce.Visibility = boostedStatus.Visibility
announce.VisibilityAdvanced = boostedStatus.VisibilityAdvanced announce.VisibilityAdvanced = boostedStatus.VisibilityAdvanced
announce.GTSBoostedStatus = boostedStatus announce.BoostOf = boostedStatus
return nil return nil
} }

View file

@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) {
return true, nil return true, nil
} }
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries so there's no block // there are no entries so there's no block
return false, nil return false, nil
} }

View file

@ -66,8 +66,8 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
new := true new := true
// check if we already have the status in our db // check if we already have the status in our db
maybeStatus := &gtsmodel.Status{} maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String())
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: remoteStatusID.String()}}, maybeStatus); err == nil { if err == nil {
// we've seen this status before so it's not new // we've seen this status before so it's not new
new = false new = false
@ -109,7 +109,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
} }
if err := d.db.Put(gtsStatus); err != nil { if err := d.db.PutStatus(gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err) return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err)
} }
} else { } else {
@ -276,7 +276,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// * the remote URL (a.RemoteURL) // * the remote URL (a.RemoteURL)
// This should be enough to pass along to the media processor. // This should be enough to pass along to the media processor.
attachmentIDs := []string{} attachmentIDs := []string{}
for _, a := range status.GTSMediaAttachments { for _, a := range status.Attachments {
l.Tracef("dereferencing attachment: %+v", a) l.Tracef("dereferencing attachment: %+v", a)
// it might have been processed elsewhere so check first if it's already in the database or not // it might have been processed elsewhere so check first if it's already in the database or not
@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
attachmentIDs = append(attachmentIDs, maybeAttachment.ID) attachmentIDs = append(attachmentIDs, maybeAttachment.ID)
continue continue
} }
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// we have a real error // we have a real error
return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err) return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err)
} }
@ -307,7 +307,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
} }
attachmentIDs = append(attachmentIDs, deferencedAttachment.ID) attachmentIDs = append(attachmentIDs, deferencedAttachment.ID)
} }
status.Attachments = attachmentIDs status.AttachmentIDs = attachmentIDs
// 2. Hashtags // 2. Hashtags
@ -317,53 +317,84 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// At this point, mentions should have the namestring and mentionedAccountURI set on them. // At this point, mentions should have the namestring and mentionedAccountURI set on them.
// //
// We should dereference any accounts mentioned here which we don't have in our db yet, by their URI. // We should dereference any accounts mentioned here which we don't have in our db yet, by their URI.
mentions := []string{} mentionIDs := []string{}
for _, m := range status.GTSMentions { for _, m := range status.Mentions {
if m.ID != "" { if m.ID != "" {
continue
// we've already populated this mention, since it has an ID // we've already populated this mention, since it has an ID
l.Debug("mention already populated")
continue
}
if m.TargetAccountURI == "" {
// can't do anything with this mention
l.Debug("target URI not set on mention")
continue
}
targetAccountURI, err := url.Parse(m.TargetAccountURI)
if err != nil {
l.Debugf("error parsing mentioned account uri %s: %s", m.TargetAccountURI, err)
continue
}
var targetAccount *gtsmodel.Account
if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil {
targetAccount = a
} else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil {
targetAccount = a
} else {
// we can't find the target account so bail
l.Debug("can't retrieve account targeted by mention")
continue
} }
mID, err := id.NewRandomULID() mID, err := id.NewRandomULID()
if err != nil { if err != nil {
return err return err
} }
m.ID = mID
uri, err := url.Parse(m.MentionedAccountURI) m = &gtsmodel.Mention{
if err != nil { ID: mID,
l.Debugf("error parsing mentioned account uri %s: %s", m.MentionedAccountURI, err) StatusID: status.ID,
continue Status: m.Status,
CreatedAt: status.CreatedAt,
UpdatedAt: status.UpdatedAt,
OriginAccountID: status.Account.ID,
OriginAccountURI: status.AccountURI,
OriginAccount: status.Account,
TargetAccountID: targetAccount.ID,
TargetAccount: targetAccount,
NameString: m.NameString,
TargetAccountURI: targetAccount.URI,
TargetAccountURL: targetAccount.URL,
} }
m.StatusID = status.ID
m.OriginAccountID = status.GTSAuthorAccount.ID
m.OriginAccountURI = status.GTSAuthorAccount.URI
targetAccount, _, err := d.GetRemoteAccount(requestingUsername, uri, false)
if err != nil {
continue
}
// by this point, we know the targetAccount exists in our database with an ID :)
m.TargetAccountID = targetAccount.ID
if err := d.db.Put(m); err != nil { if err := d.db.Put(m); err != nil {
return fmt.Errorf("error creating mention: %s", err) return fmt.Errorf("error creating mention: %s", err)
} }
mentions = append(mentions, m.ID) mentionIDs = append(mentionIDs, m.ID)
} }
status.Mentions = mentions status.MentionIDs = mentionIDs
// status has replyToURI but we don't have an ID yet for the status it replies to // status has replyToURI but we don't have an ID yet for the status it replies to
if status.InReplyToURI != "" && status.InReplyToID == "" { if status.InReplyToURI != "" && status.InReplyToID == "" {
replyToStatus := &gtsmodel.Status{} statusURI, err := url.Parse(status.InReplyToURI)
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: status.InReplyToURI}}, replyToStatus); err == nil { if err != nil {
return err
}
if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil {
// we have the status // we have the status
status.InReplyToID = replyToStatus.ID status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
} else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil {
// we got the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
} }
} }
return nil return nil
} }

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -112,8 +112,8 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
} }
status.ID = statusID status.ID = statusID
if err := f.db.Put(status); err != nil { if err := f.db.PutStatus(status); err != nil {
if _, ok := err.(db.ErrAlreadyExists); ok { if err == db.ErrAlreadyExists {
// the status already exists in the database, which means we've already handled everything else, // the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it. // so we can just return nil here and be done with it.
return nil return nil

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (
@ -6,7 +24,6 @@ import (
"net/url" "net/url"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
@ -52,10 +69,8 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
// in a delete we only get the URI, we can't know if we have a status or a profile or something else, // in a delete we only get the URI, we can't know if we have a status or a profile or something else,
// so we have to try a few different things... // so we have to try a few different things...
where := []db.Where{{Key: "uri", Value: id.String()}} s, err := f.db.GetStatusByURI(id.String())
if err == nil {
s := &gtsmodel.Status{}
if err := f.db.GetWhere(where, s); err == nil {
// it's a status // it's a status
l.Debugf("uri is for status with id: %s", s.ID) l.Debugf("uri is for status with id: %s", s.ID)
if err := f.db.DeleteByID(s.ID, &gtsmodel.Status{}); err != nil { if err := f.db.DeleteByID(s.ID, &gtsmodel.Status{}); err != nil {
@ -69,8 +84,8 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
} }
} }
a := &gtsmodel.Account{} a, err := f.db.GetAccountByURI(id.String())
if err := f.db.GetWhere(where, a); err == nil { if err == nil {
// it's an account // it's an account
l.Debugf("uri is for an account with id: %s", s.ID) l.Debugf("uri is for an account with id: %s", s.ID)
if err := f.db.DeleteByID(a.ID, &gtsmodel.Account{}); err != nil { if err := f.db.DeleteByID(a.ID, &gtsmodel.Account{}); err != nil {

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -31,7 +31,8 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
acct := &gtsmodel.Account{} acct := &gtsmodel.Account{}
if util.IsUserPath(actorIRI) { if util.IsUserPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: actorIRI.String()}}, acct); err != nil { acct, err = f.db.GetAccountByURI(actorIRI.String())
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account with uri %s: %s", actorIRI.String(), err) return nil, fmt.Errorf("FOLLOWERS: db error getting account with uri %s: %s", actorIRI.String(), err)
} }
} else if util.IsFollowersPath(actorIRI) { } else if util.IsFollowersPath(actorIRI) {
@ -42,8 +43,8 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
return nil, fmt.Errorf("FOLLOWERS: could not parse actor IRI %s as users or followers path", actorIRI.String()) return nil, fmt.Errorf("FOLLOWERS: could not parse actor IRI %s as users or followers path", actorIRI.String())
} }
acctFollowers := []gtsmodel.Follow{} acctFollowers, err := f.db.GetAccountFollowedBy(acct.ID, false)
if err := f.db.GetFollowersByAccountID(acct.ID, &acctFollowers, false); err != nil { if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting followers for account id %s: %s", acct.ID, err) return nil, fmt.Errorf("FOLLOWERS: db error getting followers for account id %s: %s", acct.ID, err)
} }

View file

@ -8,7 +8,6 @@ import (
"github.com/go-fed/activity/streams" "github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab" "github.com/go-fed/activity/streams/vocab"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
@ -28,21 +27,37 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
) )
l.Debugf("entering FOLLOWING function with actorIRI %s", actorIRI.String()) l.Debugf("entering FOLLOWING function with actorIRI %s", actorIRI.String())
acct := &gtsmodel.Account{} var acct *gtsmodel.Account
if util.IsUserPath(actorIRI) { if util.IsUserPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: actorIRI.String()}}, acct); err != nil { username, err := util.ParseUserPath(actorIRI)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: error parsing user path: %s", err)
}
a, err := f.db.GetLocalAccountByUsername(username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with uri %s: %s", actorIRI.String(), err) return nil, fmt.Errorf("FOLLOWING: db error getting account with uri %s: %s", actorIRI.String(), err)
} }
acct = a
} else if util.IsFollowingPath(actorIRI) { } else if util.IsFollowingPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "following_uri", Value: actorIRI.String()}}, acct); err != nil { username, err := util.ParseFollowingPath(actorIRI)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: error parsing following path: %s", err)
}
a, err := f.db.GetLocalAccountByUsername(username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with following uri %s: %s", actorIRI.String(), err) return nil, fmt.Errorf("FOLLOWING: db error getting account with following uri %s: %s", actorIRI.String(), err)
} }
acct = a
} else { } else {
return nil, fmt.Errorf("FOLLOWING: could not parse actor IRI %s as users or following path", actorIRI.String()) return nil, fmt.Errorf("FOLLOWING: could not parse actor IRI %s as users or following path", actorIRI.String())
} }
acctFollowing := []gtsmodel.Follow{} acctFollowing, err := f.db.GetAccountFollows(acct.ID)
if err := f.db.GetFollowingByAccountID(acct.ID, &acctFollowing); err != nil { if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting following for account id %s: %s", acct.ID, err) return nil, fmt.Errorf("FOLLOWING: db error getting following for account id %s: %s", acct.ID, err)
} }

View file

@ -43,8 +43,8 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er
l.Debug("entering GET function") l.Debug("entering GET function")
if util.IsUserPath(id) { if util.IsUserPath(id) {
acct := &gtsmodel.Account{} acct, err := f.db.GetAccountByURI(id.String())
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: id.String()}}, acct); err != nil { if err != nil {
return nil, err return nil, err
} }
l.Debug("is user path! returning account") l.Debug("is user path! returning account")

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (
@ -62,7 +80,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
} }
acct := &gtsmodel.Account{} acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
} }
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String()) return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -54,16 +54,16 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, &gtsmodel.Status{}); err != nil { status, err := f.db.GetStatusByURI(uid)
if _, ok := err.(db.ErrNoEntries); ok { if err != nil {
if err == db.ErrNoEntries {
// there are no entries for this status // there are no entries for this status
return false, nil return false, nil
} }
// an actual error happened // an actual error happened
return false, fmt.Errorf("database error fetching status with id %s: %s", uid, err) return false, fmt.Errorf("database error fetching status with id %s: %s", uid, err)
} }
l.Debugf("we own url %s", id.String()) return status.Local, nil
return true, nil
} }
if util.IsUserPath(id) { if util.IsUserPath(id) {
@ -71,8 +71,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil { if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
} }
@ -88,8 +88,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil { if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
} }
@ -105,8 +105,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil { if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
} }
@ -122,8 +122,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
} }
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil { if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
} }
@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
} }
if err := f.db.GetByID(likeID, &gtsmodel.StatusFave{}); err != nil { if err := f.db.GetByID(likeID, &gtsmodel.StatusFave{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries // there are no entries
return false, nil return false, nil
} }
@ -147,8 +147,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
} }
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil { if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
} }
@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
} }
if err := f.db.GetByID(blockID, &gtsmodel.Block{}); err != nil { if err := f.db.GetByID(blockID, &gtsmodel.Block{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// there are no entries // there are no entries
return false, nil return false, nil
} }

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb package federatingdb
import ( import (

View file

@ -97,8 +97,8 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e
for iter := actorProp.Begin(); iter != actorProp.End(); iter = iter.Next() { for iter := actorProp.Begin(); iter != actorProp.End(); iter = iter.Next() {
// take the IRI of the first actor we can find (there should only be one) // take the IRI of the first actor we can find (there should only be one)
if iter.IsIRI() { if iter.IsIRI() {
actorAccount := &gtsmodel.Account{} // if there's an error here, just use the fallback behavior -- we don't need to return an error here
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: iter.GetIRI().String()}}, actorAccount); err == nil { // if there's an error here, just use the fallback behavior -- we don't need to return an error here if actorAccount, err := f.db.GetAccountByURI(iter.GetIRI().String()); err == nil {
newID, err := id.NewRandomULID() newID, err := id.NewRandomULID()
if err != nil { if err != nil {
return nil, err return nil, err
@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
} }
acct := &gtsmodel.Account{} acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil { if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String()) return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
} }
return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String()) return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String())
@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
} }
acct := &gtsmodel.Account{} acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
} }
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String()) return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -113,8 +113,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
return nil, false, errors.New("username was empty") return nil, false, errors.New("username was empty")
} }
requestedAccount := &gtsmodel.Account{} requestedAccount, err := f.db.GetLocalAccountByUsername(username)
if err := f.db.GetLocalAccountByUsername(username, requestedAccount); err != nil { if err != nil {
return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err) return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err)
} }
@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already // authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := &gtsmodel.Instance{} i := &gtsmodel.Instance{}
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil { if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// there's been an actual error // there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err) return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
} }
@ -176,8 +176,6 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// Finally, if the authentication and authorization succeeds, then // Finally, if the authentication and authorization succeeds, then
// blocked must be false and error nil. The request will continue // blocked must be false and error nil. The request will continue
// to be processed. // to be processed.
//
// TODO: implement domain block checking here as well
func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) { func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
l := f.log.WithFields(logrus.Fields{ l := f.log.WithFields(logrus.Fields{
"func": "Blocked", "func": "Blocked",
@ -191,19 +189,18 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requested account not set on request context, so couldn't determine blocks") return false, errors.New("requested account not set on request context, so couldn't determine blocks")
} }
for _, uri := range actorIRIs { blocked, err := f.db.AreURIsBlocked(actorIRIs)
blockedDomain, err := f.blockedDomain(uri.Host) if err != nil {
if err != nil { return false, fmt.Errorf("error checking domain blocks: %s", err)
return false, fmt.Errorf("error checking domain block: %s", err) }
} if blocked {
if blockedDomain { return blocked, nil
return true, nil }
}
requestingAccount := &gtsmodel.Account{} for _, uri := range actorIRIs {
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil { requestingAccount, err := f.db.GetAccountByURI(uri.String())
_, ok := err.(db.ErrNoEntries) if err != nil {
if ok { if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked // we don't have an entry for this account so it's not blocked
// TODO: allow a different default to be set for this behavior // TODO: allow a different default to be set for this behavior
continue continue
@ -211,12 +208,11 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err) return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err)
} }
// check if requested account blocks requesting account blocked, err = f.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
if err := f.db.GetWhere([]db.Where{ if err != nil {
{Key: "account_id", Value: requestedAccount.ID}, return false, fmt.Errorf("error checking account block: %s", err)
{Key: "target_account_id", Value: requestingAccount.ID}, }
}, &gtsmodel.Block{}); err == nil { if blocked {
// a block exists
return true, nil return true, nil
} }
} }

View file

@ -30,7 +30,7 @@ import (
) )
func (f *federator) FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) { func (f *federator) FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) {
if blocked, err := f.blockedDomain(targetDomain); blocked || err != nil { if blocked, err := f.db.IsDomainBlocked(targetDomain); blocked || err != nil {
return nil, fmt.Errorf("FingerRemoteAccount: domain %s is blocked", targetDomain) return nil, fmt.Errorf("FingerRemoteAccount: domain %s is blocked", targetDomain)
} }

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation package federation
import "net/url" import "net/url"

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation package federation
import ( import (

View file

@ -1,23 +0,0 @@
package federation
import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (f *federator) blockedDomain(host string) (bool, error) {
b := &gtsmodel.DomainBlock{}
err := f.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
if err == nil {
// block exists
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
// there are no entries so there's no block
return false, nil
}
// there's an actual error
return false, err
}

View file

@ -45,11 +45,13 @@ type Account struct {
*/ */
// ID of the avatar as a media attachment // ID of the avatar as a media attachment
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"` AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched? // For a non-local account, where can the header be fetched?
AvatarRemoteURL string AvatarRemoteURL string
// ID of the header as a media attachment // ID of the header as a media attachment
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"` HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched? // For a non-local account, where can the header be fetched?
HeaderRemoteURL string HeaderRemoteURL string
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes. // DisplayName for this account. Can be empty, then just the Username will be used for display purposes.

View file

@ -31,7 +31,8 @@ type DomainBlock struct {
// When was this block updated // When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block // Account ID of the creator of this block
CreatedByAccountID string `pg:"type:CHAR(26),notnull"` CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccount *Account `pg:"rel:belongs-to"`
// Private comment on this block, viewable to admins // Private comment on this block, viewable to admins
PrivateComment string PrivateComment string
// Public comment on this block, viewable (optionally) by everyone // Public comment on this block, viewable (optionally) by everyone

View file

@ -31,5 +31,6 @@ type EmailDomainBlock struct {
// When was this block updated // When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block // Account ID of the creator of this block
CreatedByAccountID string `pg:"type:CHAR(26),notnull"` CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccount *Account `pg:"rel:belongs-to"`
} }

View file

@ -73,5 +73,6 @@ type Emoji struct {
// Is this emoji visible in the admin emoji picker? // Is this emoji visible in the admin emoji picker?
VisibleInPicker bool `pg:",notnull,default:true"` VisibleInPicker bool `pg:",notnull,default:true"`
// In which emoji category is this emoji visible? // In which emoji category is this emoji visible?
CategoryID string `pg:"type:CHAR(26)"` CategoryID string `pg:"type:CHAR(26)"`
Status *Status `pg:"rel:belongs-to"`
} }

View file

@ -29,9 +29,11 @@ type Follow struct {
// When was this follow last updated? // When was this follow last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow belong to? // Who does this follow belong to?
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
Account *Account `pg:"rel:belongs-to"`
// Who does AccountID follow? // Who does AccountID follow?
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Does this follow also want to see reblogs and not just posts? // Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"` ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow? // What is the activitypub URI of this follow?

View file

@ -29,9 +29,11 @@ type FollowRequest struct {
// When was this follow request last updated? // When was this follow request last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow request originate from? // Who does this follow request originate from?
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
Account Account `pg:"rel:has-one"`
// Who is the target of this follow request? // Who is the target of this follow request?
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccount Account `pg:"rel:has-one"`
// Does this follow also want to see reblogs and not just posts? // Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"` ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow request? // What is the activitypub URI of this follow request?

View file

@ -19,7 +19,8 @@ type Instance struct {
// When was this instance suspended, if at all? // When was this instance suspended, if at all?
SuspendedAt time.Time SuspendedAt time.Time
// ID of any existing domain block for this instance in the database // ID of any existing domain block for this instance in the database
DomainBlockID string `pg:"type:CHAR(26)"` DomainBlockID string `pg:"type:CHAR(26)"`
DomainBlock *DomainBlock `pg:"rel:has-one"`
// Short description of this instance // Short description of this instance
ShortDescription string ShortDescription string
// Longer description of this instance // Longer description of this instance
@ -31,7 +32,8 @@ type Instance struct {
// Username of the contact account for this instance // Username of the contact account for this instance
ContactAccountUsername string ContactAccountUsername string
// Contact account ID in the database for this instance // Contact account ID in the database for this instance
ContactAccountID string `pg:"type:CHAR(26)"` ContactAccountID string `pg:"type:CHAR(26)"`
ContactAccount *Account `pg:"rel:has-one"`
// Reputation score of this instance // Reputation score of this instance
Reputation int64 `pg:",notnull,default:0"` Reputation int64 `pg:",notnull,default:0"`
// Version of the software used on this instance // Version of the software used on this instance

View file

@ -42,7 +42,8 @@ type MediaAttachment struct {
// Metadata about the file // Metadata about the file
FileMeta FileMeta FileMeta FileMeta
// To which account does this attachment belong // To which account does this attachment belong
AccountID string `pg:"type:CHAR(26),notnull"` AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// Description of the attachment (for screenreaders) // Description of the attachment (for screenreaders)
Description string Description string
// To which scheduled status does this attachment belong // To which scheduled status does this attachment belong

View file

@ -25,17 +25,20 @@ type Mention struct {
// ID of this mention in the database // ID of this mention in the database
ID string `pg:"type:CHAR(26),pk,notnull,unique"` ID string `pg:"type:CHAR(26),pk,notnull,unique"`
// ID of the status this mention originates from // ID of the status this mention originates from
StatusID string `pg:"type:CHAR(26),notnull"` StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:belongs-to"`
// When was this mention created? // When was this mention created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this mention last updated? // When was this mention last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// What's the internal account ID of the originator of the mention? // What's the internal account ID of the originator of the mention?
OriginAccountID string `pg:"type:CHAR(26),notnull"` OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccount *Account `pg:"rel:has-one"`
// What's the AP URI of the originator of the mention? // What's the AP URI of the originator of the mention?
OriginAccountURI string `pg:",notnull"` OriginAccountURI string `pg:",notnull"`
// What's the internal account ID of the mention target? // What's the internal account ID of the mention target?
TargetAccountID string `pg:"type:CHAR(26),notnull"` TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Prevent this mention from generating a notification? // Prevent this mention from generating a notification?
Silent bool Silent bool
@ -52,14 +55,14 @@ type Mention struct {
// //
// This will not be put in the database, it's just for convenience. // This will not be put in the database, it's just for convenience.
NameString string `pg:"-"` NameString string `pg:"-"`
// MentionedAccountURI is the AP ID (uri) of the user mentioned. // TargetAccountURI is the AP ID (uri) of the user mentioned.
// //
// This will not be put in the database, it's just for convenience. // This will not be put in the database, it's just for convenience.
MentionedAccountURI string `pg:"-"` TargetAccountURI string `pg:"-"`
// MentionedAccountURL is the web url of the user mentioned. // TargetAccountURL is the web url of the user mentioned.
// //
// This will not be put in the database, it's just for convenience. // This will not be put in the database, it's just for convenience.
MentionedAccountURL string `pg:"-"` TargetAccountURL string `pg:"-"`
// A pointer to the gtsmodel account of the mentioned account. // A pointer to the gtsmodel account of the mentioned account.
GTSAccount *Account `pg:"-"`
} }

View file

@ -29,24 +29,16 @@ type Notification struct {
// Creation time of this notification // Creation time of this notification
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Which account does this notification target (ie., who will receive the notification?) // Which account does this notification target (ie., who will receive the notification?)
TargetAccountID string `pg:"type:CHAR(26),notnull"` TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Which account performed the action that created this notification? // Which account performed the action that created this notification?
OriginAccountID string `pg:"type:CHAR(26),notnull"` OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccount *Account `pg:"rel:has-one"`
// If the notification pertains to a status, what is the database ID of that status? // If the notification pertains to a status, what is the database ID of that status?
StatusID string `pg:"type:CHAR(26)"` StatusID string `pg:"type:CHAR(26)"`
Status *Status `pg:"rel:has-one"`
// Has this notification been read already? // Has this notification been read already?
Read bool Read bool
/*
NON-DATABASE fields
*/
// gts model of the target account, won't be put in the database, it's just for convenience when passing the notification around.
GTSTargetAccount *Account `pg:"-"`
// gts model of the origin account, won't be put in the database, it's just for convenience when passing the notification around.
GTSOriginAccount *Account `pg:"-"`
// gts model of the relevant status, won't be put in the database, it's just for convenience when passing the notification around.
GTSStatus *Status `pg:"-"`
} }
// NotificationType describes the reason/type of this notification. // NotificationType describes the reason/type of this notification.

View file

@ -33,13 +33,17 @@ type Status struct {
// the html-formatted content of this status // the html-formatted content of this status
Content string Content string
// Database IDs of any media attachments associated with this status // Database IDs of any media attachments associated with this status
Attachments []string `pg:",array"` AttachmentIDs []string `pg:"attachments,array"`
Attachments []*MediaAttachment `pg:"attached_media,rel:has-many"`
// Database IDs of any tags used in this status // Database IDs of any tags used in this status
Tags []string `pg:",array"` TagIDs []string `pg:"tags,array"`
Tags []*Tag `pg:"attached_tags,many2many:status_to_tags"` // https://pg.uptrace.dev/orm/many-to-many-relation/
// Database IDs of any mentions in this status // Database IDs of any mentions in this status
Mentions []string `pg:",array"` MentionIDs []string `pg:"mentions,array"`
Mentions []*Mention `pg:"attached_mentions,rel:has-many"`
// Database IDs of any emojis used in this status // Database IDs of any emojis used in this status
Emojis []string `pg:",array"` EmojiIDs []string `pg:"emojis,array"`
Emojis []*Emoji `pg:"attached_emojis,many2many:status_to_emojis"` // https://pg.uptrace.dev/orm/many-to-many-relation/
// when was this status created? // when was this status created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// when was this status updated? // when was this status updated?
@ -47,19 +51,24 @@ type Status struct {
// is this status from a local account? // is this status from a local account?
Local bool Local bool
// which account posted this status? // which account posted this status?
AccountID string `pg:"type:CHAR(26),notnull"` AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
// AP uri of the owner of this status // AP uri of the owner of this status
AccountURI string AccountURI string
// id of the status this status is a reply to // id of the status this status is a reply to
InReplyToID string `pg:"type:CHAR(26)"` InReplyToID string `pg:"type:CHAR(26)"`
InReplyTo *Status `pg:"rel:has-one"`
// AP uri of the status this status is a reply to // AP uri of the status this status is a reply to
InReplyToURI string InReplyToURI string
// id of the account that this status replies to // id of the account that this status replies to
InReplyToAccountID string `pg:"type:CHAR(26)"` InReplyToAccountID string `pg:"type:CHAR(26)"`
InReplyToAccount *Account `pg:"rel:has-one"`
// id of the status this status is a boost of // id of the status this status is a boost of
BoostOfID string `pg:"type:CHAR(26)"` BoostOfID string `pg:"type:CHAR(26)"`
BoostOf *Status `pg:"rel:has-one"`
// id of the account that owns the boosted status // id of the account that owns the boosted status
BoostOfAccountID string `pg:"type:CHAR(26)"` BoostOfAccountID string `pg:"type:CHAR(26)"`
BoostOfAccount *Account `pg:"rel:has-one"`
// cw string for this status // cw string for this status
ContentWarning string ContentWarning string
// visibility entry for this status // visibility entry for this status
@ -69,7 +78,8 @@ type Status struct {
// what language is this status written in? // what language is this status written in?
Language string Language string
// Which application was used to create this status? // Which application was used to create this status?
CreatedWithApplicationID string `pg:"type:CHAR(26)"` CreatedWithApplicationID string `pg:"type:CHAR(26)"`
CreatedWithApplication *Application `pg:"rel:has-one"`
// advanced visibility for this status // advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
@ -79,32 +89,18 @@ type Status struct {
Text string Text string
// Has this status been pinned by its owner? // Has this status been pinned by its owner?
Pinned bool Pinned bool
}
/* // StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
INTERNAL MODEL NON-DATABASE FIELDS type StatusToTag struct {
StatusID string `pg:"unique:statustag"`
TagID string `pg:"unique:statustag"`
}
These are for convenience while passing the status around internally, // StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis.
but these fields should *never* be put in the db. type StatusToEmoji struct {
*/ StatusID string `pg:"unique:statusemoji"`
EmojiID string `pg:"unique:statusemoji"`
// Account that created this status
GTSAuthorAccount *Account `pg:"-"`
// Mentions created in this status
GTSMentions []*Mention `pg:"-"`
// Hashtags used in this status
GTSTags []*Tag `pg:"-"`
// Emojis used in this status
GTSEmojis []*Emoji `pg:"-"`
// MediaAttachments used in this status
GTSMediaAttachments []*MediaAttachment `pg:"-"`
// Status being replied to
GTSReplyToStatus *Status `pg:"-"`
// Account being replied to
GTSReplyToAccount *Account `pg:"-"`
// Status being boosted
GTSBoostedStatus *Status `pg:"-"`
// Account of the boosted status
GTSBoostedAccount *Account `pg:"-"`
} }
// Visibility represents the visibility granularity of a status. // Visibility represents the visibility granularity of a status.

View file

@ -27,9 +27,11 @@ type StatusBookmark struct {
// when was this bookmark created // when was this bookmark created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the bookmarking // id of the account that created ('did') the bookmarking
AccountID string `pg:"type:CHAR(26),notnull"` AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// id the account owning the bookmarked status // id the account owning the bookmarked status
TargetAccountID string `pg:"type:CHAR(26),notnull"` TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been bookmarked // database id of the status that has been bookmarked
StatusID string `pg:"type:CHAR(26),notnull"` StatusID string `pg:"type:CHAR(26),notnull"`
} }

View file

@ -27,18 +27,14 @@ type StatusFave struct {
// when was this fave created // when was this fave created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the fave // id of the account that created ('did') the fave
AccountID string `pg:"type:CHAR(26),notnull"` AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
// id the account owning the faved status // id the account owning the faved status
TargetAccountID string `pg:"type:CHAR(26),notnull"` TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been 'faved' // database id of the status that has been 'faved'
StatusID string `pg:"type:CHAR(26),notnull"` StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:has-one"`
// ActivityPub URI of this fave // ActivityPub URI of this fave
URI string `pg:",notnull"` URI string `pg:",notnull"`
// GTSStatus is the status being interacted with. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSStatus *Status `pg:"-"`
// GTSTargetAccount is the account being interacted with. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSTargetAccount *Account `pg:"-"`
// GTSFavingAccount is the account doing the faving. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSFavingAccount *Account `pg:"-"`
} }

View file

@ -27,9 +27,12 @@ type StatusMute struct {
// when was this mute created // when was this mute created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the mute // id of the account that created ('did') the mute
AccountID string `pg:"type:CHAR(26),notnull"` AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// id the account owning the muted status (can be the same as accountID) // id the account owning the muted status (can be the same as accountID)
TargetAccountID string `pg:"type:CHAR(26),notnull"` TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been muted // database id of the status that has been muted
StatusID string `pg:"type:CHAR(26),notnull"` StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:has-one"`
} }

View file

@ -27,7 +27,7 @@ type Tag struct {
// Href of this tag, eg https://example.org/tags/somehashtag // Href of this tag, eg https://example.org/tags/somehashtag
URL string URL string
// name of this tag -- the tag without the hash part // name of this tag -- the tag without the hash part
Name string `pg:",unique,pk,notnull"` Name string `pg:",unique,notnull"`
// Which account ID is the first one we saw using this tag? // Which account ID is the first one we saw using this tag?
FirstSeenFromAccountID string `pg:"type:CHAR(26)"` FirstSeenFromAccountID string `pg:"type:CHAR(26)"`
// when was this tag created // when was this tag created

View file

@ -35,7 +35,8 @@ type User struct {
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported // confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
Email string `pg:"default:null,unique"` Email string `pg:"default:null,unique"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet) // The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
AccountID string `pg:"type:CHAR(26),unique"` AccountID string `pg:"type:CHAR(26),unique"`
Account *Account `pg:"rel:has-one"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables // The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
EncryptedPassword string `pg:",notnull"` EncryptedPassword string `pg:",notnull"`
@ -68,7 +69,8 @@ type User struct {
// In what timezone/locale is this user located? // In what timezone/locale is this user located?
Locale string Locale string
// Which application id created this user? See gtsmodel.Application // Which application id created this user? See gtsmodel.Application
CreatedByApplicationID string `pg:"type:CHAR(26)"` CreatedByApplicationID string `pg:"type:CHAR(26)"`
CreatedByApplication *Application `pg:"rel:has-one"`
// When did we last contact this user // When did we last contact this user
LastEmailedAt time.Time `pg:"type:timestamp"` LastEmailedAt time.Time `pg:"type:timestamp"`

View file

@ -142,7 +142,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin
} }
// set it in the database // set it in the database
if err := mh.db.SetHeaderOrAvatarForAccountID(ma, accountID); err != nil { if err := mh.db.SetAccountHeaderOrAvatar(ma, accountID); err != nil {
return nil, fmt.Errorf("error putting %s in database: %s", mediaType, err) return nil, fmt.Errorf("error putting %s in database: %s", mediaType, err)
} }
@ -231,8 +231,8 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (
// since emoji aren't 'owned' by an account, but we still want to use the same pattern for serving them through the filserver, // since emoji aren't 'owned' by an account, but we still want to use the same pattern for serving them through the filserver,
// (ie., fileserver/ACCOUNT_ID/etc etc) we need to fetch the INSTANCE ACCOUNT from the database. That is, the account that's created // (ie., fileserver/ACCOUNT_ID/etc etc) we need to fetch the INSTANCE ACCOUNT from the database. That is, the account that's created
// with the same username as the instance hostname, which doesn't belong to any particular user. // with the same username as the instance hostname, which doesn't belong to any particular user.
instanceAccount := &gtsmodel.Account{} instanceAccount, err := mh.db.GetInstanceAccount("")
if err := mh.db.GetLocalAccountByUsername(mh.config.Host, instanceAccount); err != nil { if err != nil {
return nil, fmt.Errorf("error fetching instance account: %s", err) return nil, fmt.Errorf("error fetching instance account: %s", err)
} }

View file

@ -27,11 +27,11 @@ import (
) )
type clientStore struct { type clientStore struct {
db db.DB db db.Basic
} }
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
func NewClientStore(db db.DB) oauth2.ClientStore { func NewClientStore(db db.Basic) oauth2.ClientStore {
pts := &clientStore{ pts := &clientStore{
db: db, db: db,
} }

View file

@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error // try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID) deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient) suite.Assert().Nil(deletedClient)
suite.Assert().EqualValues(db.ErrNoEntries{}, err) suite.Assert().EqualValues(db.ErrNoEntries, err)
} }
func TestPgClientStoreTestSuite(t *testing.T) { func TestPgClientStoreTestSuite(t *testing.T) {

View file

@ -66,7 +66,7 @@ type s struct {
} }
// New returns a new oauth server that implements the Server interface // New returns a new oauth server that implements the Server interface
func New(database db.DB, log *logrus.Logger) Server { func New(database db.Basic, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log) ts := newTokenStore(context.Background(), database, log)
cs := NewClientStore(database) cs := NewClientStore(database)

View file

@ -34,7 +34,7 @@ import (
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct { type tokenStore struct {
oauth2.TokenStore oauth2.TokenStore
db db.DB db db.Basic
log *logrus.Logger log *logrus.Logger
} }
@ -42,7 +42,7 @@ type tokenStore struct {
// //
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired. // the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore { func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{ pts := &tokenStore{
db: db, db: db,
log: log, log: log,

View file

@ -31,24 +31,20 @@ import (
func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{} targetAccount, err := p.db.GetAccountByID(targetAccountID)
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { if err != nil {
if _, ok := err.(db.ErrNoEntries); ok { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err))
}
} }
// if requestingAccount already blocks target account, we don't need to do anything // if requestingAccount already blocks target account, we don't need to do anything
block := &gtsmodel.Block{} if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil {
if err := p.db.GetWhere([]db.Where{ return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
{Key: "account_id", Value: requestingAccount.ID}, } else if blocked {
{Key: "target_account_id", Value: targetAccountID},
}, block); err == nil {
// block already exists, just return relationship
return p.RelationshipGet(requestingAccount, targetAccountID) return p.RelationshipGet(requestingAccount, targetAccountID)
} }
// make the block // make the block
block := &gtsmodel.Block{}
newBlockID, err := id.NewULID() newBlockID, err := id.NewULID()
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -57,7 +53,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
block.AccountID = requestingAccount.ID block.AccountID = requestingAccount.ID
block.Account = requestingAccount block.Account = requestingAccount
block.TargetAccountID = targetAccountID block.TargetAccountID = targetAccountID
block.TargetAccount = targetAcct block.TargetAccount = targetAccount
block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID) block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID)
// whack it in the database // whack it in the database
@ -123,7 +119,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
URI: frURI, URI: frURI,
}, },
OriginAccount: requestingAccount, OriginAccount: requestingAccount,
TargetAccount: targetAcct, TargetAccount: targetAccount,
} }
} }
@ -138,7 +134,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
URI: fURI, URI: fURI,
}, },
OriginAccount: requestingAccount, OriginAccount: requestingAccount,
TargetAccount: targetAcct, TargetAccount: targetAccount,
} }
} }
@ -148,7 +144,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
APActivityType: gtsmodel.ActivityStreamsCreate, APActivityType: gtsmodel.ActivityStreamsCreate,
GTSModel: block, GTSModel: block,
OriginAccount: requestingAccount, OriginAccount: requestingAccount,
TargetAccount: targetAcct, TargetAccount: targetAccount,
} }
return p.RelationshipGet(requestingAccount, targetAccountID) return p.RelationshipGet(requestingAccount, targetAccountID)

View file

@ -31,38 +31,33 @@ import (
func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc // if there's a block between the accounts we shouldn't create the request ofc
blocked, err := p.db.Blocked(requestingAccount.ID, form.ID) if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil {
if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} } else if blocked {
if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: block exists between accounts"))
} }
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{} targetAcct, err := p.db.GetAccountByID(form.ID)
if err := p.db.GetByID(form.ID, targetAcct); err != nil { if err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
} }
return nil, gtserror.NewErrorInternalError(err)
} }
// check if a follow exists already // check if a follow exists already
follows, err := p.db.Follows(requestingAccount, targetAcct) if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil {
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
} } else if follows {
if follows {
// already follows so just return the relationship // already follows so just return the relationship
return p.RelationshipGet(requestingAccount, form.ID) return p.RelationshipGet(requestingAccount, form.ID)
} }
// check if a follow exists already // check if a follow request exists already
followRequested, err := p.db.FollowRequested(requestingAccount, targetAcct) if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil {
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
} } else if followRequested {
if followRequested {
// already follow requested so just return the relationship // already follow requested so just return the relationship
return p.RelationshipGet(requestingAccount, form.ID) return p.RelationshipGet(requestingAccount, form.ID)
} }

View file

@ -133,9 +133,9 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
var maxID string var maxID string
selectStatusesLoop: selectStatusesLoop:
for { for {
statuses, err := p.db.GetStatusesForAccount(account.ID, 20, false, maxID, false, false) statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
if err != nil { if err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
// no statuses left for this instance so we're done // no statuses left for this instance so we're done
l.Infof("Delete: done iterating through statuses for account %s", account.Username) l.Infof("Delete: done iterating through statuses for account %s", account.Username)
break selectStatusesLoop break selectStatusesLoop
@ -147,7 +147,7 @@ selectStatusesLoop:
for i, s := range statuses { for i, s := range statuses {
// pass the status delete through the client api channel for processing // pass the status delete through the client api channel for processing
s.GTSAuthorAccount = account s.Account = account
l.Debug("putting status in the client api channel") l.Debug("putting status in the client api channel")
p.fromClientAPI <- gtsmodel.FromClientAPI{ p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsNote, APObjectType: gtsmodel.ActivityStreamsNote,
@ -158,7 +158,7 @@ selectStatusesLoop:
} }
if err := p.db.DeleteByID(s.ID, s); err != nil { if err := p.db.DeleteByID(s.ID, s); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// actual error has occurred // actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err) l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop break selectStatusesLoop
@ -168,7 +168,7 @@ selectStatusesLoop:
// if there are any boosts of this status, delete them as well // if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{} boosts := []*gtsmodel.Status{}
if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// an actual error has occurred // an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err) l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop break selectStatusesLoop
@ -190,7 +190,7 @@ selectStatusesLoop:
} }
if err := p.db.DeleteByID(b.ID, b); err != nil { if err := p.db.DeleteByID(b.ID, b); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok { if err != db.ErrNoEntries {
// actual error has occurred // actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err) l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
break selectStatusesLoop break selectStatusesLoop

View file

@ -30,7 +30,7 @@ import (
func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := &gtsmodel.Account{} targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil { if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return nil, errors.New("account not found") return nil, errors.New("account not found")
} }
return nil, fmt.Errorf("db error: %s", err) return nil, fmt.Errorf("db error: %s", err)
@ -39,7 +39,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
var blocked bool var blocked bool
var err error var err error
if requestingAccount != nil { if requestingAccount != nil {
blocked, err = p.db.Blocked(requestingAccount.ID, targetAccountID) blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("error checking account block: %s", err) return nil, fmt.Errorf("error checking account block: %s", err)
} }

View file

@ -28,26 +28,23 @@ import (
) )
func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
blocked, err := p.db.Blocked(requestingAccount.ID, targetAccountID) if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} } else if blocked {
if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
followers := []gtsmodel.Follow{}
accounts := []apimodel.Account{} accounts := []apimodel.Account{}
if err := p.db.GetFollowersByAccountID(targetAccountID, &followers, false); err != nil { follows, err := p.db.GetAccountFollowedBy(targetAccountID, false)
if _, ok := err.(db.ErrNoEntries); ok { if err != nil {
if err == db.ErrNoEntries {
return accounts, nil return accounts, nil
} }
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
for _, f := range followers { for _, f := range follows {
blocked, err := p.db.Blocked(requestingAccount.ID, f.AccountID) blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -55,15 +52,18 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
continue continue
} }
a := &gtsmodel.Account{} if f.Account == nil {
if err := p.db.GetByID(f.AccountID, a); err != nil { a, err := p.db.GetAccountByID(f.AccountID)
if _, ok := err.(db.ErrNoEntries); ok { if err != nil {
continue if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
} }
return nil, gtserror.NewErrorInternalError(err) f.Account = a
} }
account, err := p.tc.AccountToMastoPublic(a) account, err := p.tc.AccountToMastoPublic(f.Account)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -28,26 +28,23 @@ import (
) )
func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
blocked, err := p.db.Blocked(requestingAccount.ID, targetAccountID) if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} } else if blocked {
if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
following := []gtsmodel.Follow{}
accounts := []apimodel.Account{} accounts := []apimodel.Account{}
if err := p.db.GetFollowingByAccountID(targetAccountID, &following); err != nil { follows, err := p.db.GetAccountFollows(targetAccountID)
if _, ok := err.(db.ErrNoEntries); ok { if err != nil {
if err == db.ErrNoEntries {
return accounts, nil return accounts, nil
} }
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
for _, f := range following { for _, f := range follows {
blocked, err := p.db.Blocked(requestingAccount.ID, f.AccountID) blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -55,15 +52,18 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
continue continue
} }
a := &gtsmodel.Account{} if f.TargetAccount == nil {
if err := p.db.GetByID(f.TargetAccountID, a); err != nil { a, err := p.db.GetAccountByID(f.TargetAccountID)
if _, ok := err.(db.ErrNoEntries); ok { if err != nil {
continue if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
} }
return nil, gtserror.NewErrorInternalError(err) f.TargetAccount = a
} }
account, err := p.tc.AccountToMastoPublic(a) account, err := p.tc.AccountToMastoPublic(f.TargetAccount)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -28,18 +28,17 @@ import (
) )
func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
targetAccount := &gtsmodel.Account{} if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID))
}
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
apiStatuses := []apimodel.Status{} apiStatuses := []apimodel.Status{}
statuses, err := p.db.GetStatusesForAccount(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil { if err != nil {
if _, ok := err.(db.ErrNoEntries); ok { if err == db.ErrNoEntries {
return apiStatuses, nil return apiStatuses, nil
} }
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)

View file

@ -29,11 +29,9 @@ import (
func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{} targetAccount, err := p.db.GetAccountByID(targetAccountID)
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { if err != nil {
if _, ok := err.(db.ErrNoEntries); ok { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err))
}
} }
// check if a block exists, and remove it if it does (storing the URI for later) // check if a block exists, and remove it if it does (storing the URI for later)
@ -44,7 +42,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
{Key: "target_account_id", Value: targetAccountID}, {Key: "target_account_id", Value: targetAccountID},
}, block); err == nil { }, block); err == nil {
block.Account = requestingAccount block.Account = requestingAccount
block.TargetAccount = targetAcct block.TargetAccount = targetAccount
if err := p.db.DeleteByID(block.ID, &gtsmodel.Block{}); err != nil { if err := p.db.DeleteByID(block.ID, &gtsmodel.Block{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
} }
@ -58,7 +56,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
APActivityType: gtsmodel.ActivityStreamsUndo, APActivityType: gtsmodel.ActivityStreamsUndo,
GTSModel: block, GTSModel: block,
OriginAccount: requestingAccount, OriginAccount: requestingAccount,
TargetAccount: targetAcct, TargetAccount: targetAccount,
} }
} }

Some files were not shown because too many files have changed in this diff Show more