2021-03-02 17:26:30 +00:00
|
|
|
/*
|
|
|
|
GoToSocial
|
|
|
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
|
|
|
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
|
|
it under the terms of the GNU Affero General Public License as published by
|
|
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
|
|
(at your option) any later version.
|
|
|
|
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
GNU Affero General Public License for more details.
|
|
|
|
|
|
|
|
You should have received a copy of the GNU Affero General Public License
|
|
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
*/
|
|
|
|
|
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2021-04-01 18:46:45 +00:00
|
|
|
"crypto/rand"
|
|
|
|
"crypto/rsa"
|
2021-03-02 17:26:30 +00:00
|
|
|
"errors"
|
|
|
|
"fmt"
|
2021-04-01 18:46:45 +00:00
|
|
|
"net"
|
|
|
|
"net/mail"
|
2021-03-04 13:38:18 +00:00
|
|
|
"regexp"
|
2021-03-03 17:12:02 +00:00
|
|
|
"strings"
|
2021-03-02 21:52:31 +00:00
|
|
|
"time"
|
2021-03-02 17:26:30 +00:00
|
|
|
|
2021-03-22 21:26:54 +00:00
|
|
|
"github.com/go-fed/activity/pub"
|
2021-03-05 17:31:12 +00:00
|
|
|
"github.com/go-pg/pg/extra/pgdebug"
|
|
|
|
"github.com/go-pg/pg/v10"
|
|
|
|
"github.com/go-pg/pg/v10/orm"
|
2021-03-02 21:52:31 +00:00
|
|
|
"github.com/sirupsen/logrus"
|
2021-04-01 18:46:45 +00:00
|
|
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
|
|
|
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
|
|
|
"github.com/superseriousbusiness/gotosocial/internal/util"
|
|
|
|
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
2021-03-02 17:26:30 +00:00
|
|
|
)
|
|
|
|
|
2021-03-22 21:26:54 +00:00
|
|
|
// postgresService satisfies the DB interface
|
2021-03-02 17:26:30 +00:00
|
|
|
type postgresService struct {
|
2021-04-01 18:46:45 +00:00
|
|
|
config *config.Config
|
2021-03-22 21:26:54 +00:00
|
|
|
conn *pg.DB
|
|
|
|
log *logrus.Entry
|
|
|
|
cancel context.CancelFunc
|
|
|
|
federationDB pub.Database
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
|
|
|
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
|
2021-04-01 18:46:45 +00:00
|
|
|
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
|
2021-03-04 13:38:18 +00:00
|
|
|
opts, err := derivePGOptions(c)
|
2021-03-02 17:26:30 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not create postgres service: %s", err)
|
|
|
|
}
|
2021-03-05 17:31:12 +00:00
|
|
|
log.Debugf("using pg options: %+v", opts)
|
2021-03-02 21:52:31 +00:00
|
|
|
|
|
|
|
readyChan := make(chan interface{})
|
2021-03-05 17:31:12 +00:00
|
|
|
opts.OnConnect = func(ctx context.Context, c *pg.Conn) error {
|
2021-03-02 21:52:31 +00:00
|
|
|
close(readyChan)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// create a connection
|
|
|
|
pgCtx, cancel := context.WithCancel(ctx)
|
|
|
|
conn := pg.Connect(opts).WithContext(pgCtx)
|
|
|
|
|
2021-03-05 17:31:12 +00:00
|
|
|
// this will break the logfmt format we normally log in,
|
|
|
|
// since we can't choose where pg outputs to and it defaults to
|
|
|
|
// stdout. So use this option with care!
|
|
|
|
if log.Logger.GetLevel() >= logrus.TraceLevel {
|
|
|
|
conn.AddQueryHook(pgdebug.DebugHook{
|
|
|
|
// Print all queries.
|
|
|
|
Verbose: true,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2021-03-02 21:52:31 +00:00
|
|
|
// actually *begin* the connection so that we can tell if the db is there
|
|
|
|
// and listening, and also trigger the opts.OnConnect function passed in above
|
2021-03-05 17:31:12 +00:00
|
|
|
if err := conn.Ping(ctx); err != nil {
|
2021-03-02 21:52:31 +00:00
|
|
|
cancel()
|
|
|
|
return nil, fmt.Errorf("db connection error: %s", err)
|
|
|
|
}
|
|
|
|
|
2021-03-05 17:31:12 +00:00
|
|
|
// print out discovered postgres version
|
|
|
|
var version string
|
|
|
|
if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
|
2021-03-02 21:52:31 +00:00
|
|
|
cancel()
|
|
|
|
return nil, fmt.Errorf("db connection error: %s", err)
|
|
|
|
}
|
2021-03-05 17:31:12 +00:00
|
|
|
log.Infof("connected to postgres version: %s", version)
|
2021-03-02 21:52:31 +00:00
|
|
|
|
|
|
|
// make sure the opts.OnConnect function has been triggered
|
|
|
|
// and closed the ready channel
|
|
|
|
select {
|
|
|
|
case <-readyChan:
|
|
|
|
log.Infof("postgres connection ready")
|
|
|
|
case <-time.After(5 * time.Second):
|
|
|
|
cancel()
|
|
|
|
return nil, errors.New("db connection timeout")
|
|
|
|
}
|
|
|
|
|
2021-04-01 18:46:45 +00:00
|
|
|
ps := &postgresService{
|
|
|
|
config: c,
|
|
|
|
conn: conn,
|
|
|
|
log: log,
|
|
|
|
cancel: cancel,
|
|
|
|
}
|
2021-03-02 17:26:30 +00:00
|
|
|
|
2021-04-01 18:46:45 +00:00
|
|
|
federatingDB := newFederatingDB(ps, c)
|
|
|
|
ps.federationDB = federatingDB
|
|
|
|
|
|
|
|
// we can confidently return this useable postgres service now
|
|
|
|
return ps, nil
|
2021-03-22 21:26:54 +00:00
|
|
|
}
|
|
|
|
|
2021-03-02 17:26:30 +00:00
|
|
|
/*
|
|
|
|
HANDY STUFF
|
|
|
|
*/
|
|
|
|
|
|
|
|
// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
|
|
|
|
// with sensible defaults, or an error if it's not satisfied by the provided config.
|
2021-03-04 13:38:18 +00:00
|
|
|
func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
|
|
|
if strings.ToUpper(c.DBConfig.Type) != dbTypePostgres {
|
|
|
|
return nil, fmt.Errorf("expected db type of %s but got %s", dbTypePostgres, c.DBConfig.Type)
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
|
|
|
|
2021-03-04 11:07:24 +00:00
|
|
|
// validate port
|
2021-03-04 13:38:18 +00:00
|
|
|
if c.DBConfig.Port == 0 {
|
2021-03-04 11:07:24 +00:00
|
|
|
return nil, errors.New("no port set")
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// validate address
|
2021-03-04 13:38:18 +00:00
|
|
|
if c.DBConfig.Address == "" {
|
2021-03-04 11:07:24 +00:00
|
|
|
return nil, errors.New("no address set")
|
2021-03-02 21:52:31 +00:00
|
|
|
}
|
2021-03-04 13:38:18 +00:00
|
|
|
|
|
|
|
ipv4Regex := regexp.MustCompile(`^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`)
|
|
|
|
hostnameRegex := regexp.MustCompile(`^(?:[a-z0-9]+(?:-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
|
|
|
if !hostnameRegex.MatchString(c.DBConfig.Address) && !ipv4Regex.MatchString(c.DBConfig.Address) && c.DBConfig.Address != "localhost" {
|
|
|
|
return nil, fmt.Errorf("address %s was neither an ipv4 address nor a valid hostname", c.DBConfig.Address)
|
2021-03-02 21:52:31 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// validate username
|
2021-03-04 13:38:18 +00:00
|
|
|
if c.DBConfig.User == "" {
|
2021-03-04 11:07:24 +00:00
|
|
|
return nil, errors.New("no user set")
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
2021-03-02 21:52:31 +00:00
|
|
|
|
|
|
|
// validate that there's a password
|
2021-03-04 13:38:18 +00:00
|
|
|
if c.DBConfig.Password == "" {
|
2021-03-02 21:52:31 +00:00
|
|
|
return nil, errors.New("no password set")
|
|
|
|
}
|
|
|
|
|
|
|
|
// validate database
|
2021-03-04 13:38:18 +00:00
|
|
|
if c.DBConfig.Database == "" {
|
2021-03-04 11:07:24 +00:00
|
|
|
return nil, errors.New("no database set")
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
|
|
|
|
2021-03-02 21:52:31 +00:00
|
|
|
// We can rely on the pg library we're using to set
|
|
|
|
// sensible defaults for everything we don't set here.
|
2021-03-02 17:26:30 +00:00
|
|
|
options := &pg.Options{
|
2021-03-04 13:38:18 +00:00
|
|
|
Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
|
|
|
|
User: c.DBConfig.User,
|
|
|
|
Password: c.DBConfig.Password,
|
|
|
|
Database: c.DBConfig.Database,
|
|
|
|
ApplicationName: c.ApplicationName,
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return options, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
2021-04-01 18:46:45 +00:00
|
|
|
FEDERATION FUNCTIONALITY
|
2021-03-02 17:26:30 +00:00
|
|
|
*/
|
|
|
|
|
2021-04-01 18:46:45 +00:00
|
|
|
func (ps *postgresService) Federation() pub.Database {
|
|
|
|
return ps.federationDB
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
BASIC DB FUNCTIONALITY
|
|
|
|
*/
|
|
|
|
|
|
|
|
func (ps *postgresService) CreateTable(i interface{}) error {
|
|
|
|
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
|
|
|
IfNotExists: true,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) DropTable(i interface{}) error {
|
|
|
|
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
|
|
|
IfExists: true,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2021-03-02 21:52:31 +00:00
|
|
|
func (ps *postgresService) Stop(ctx context.Context) error {
|
|
|
|
ps.log.Info("closing db connection")
|
|
|
|
if err := ps.conn.Close(); err != nil {
|
|
|
|
// only cancel if there's a problem closing the db
|
|
|
|
ps.cancel()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
2021-03-02 17:26:30 +00:00
|
|
|
}
|
2021-03-05 17:31:12 +00:00
|
|
|
|
2021-04-01 18:46:45 +00:00
|
|
|
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
|
|
|
return ps.conn.Ping(ctx)
|
|
|
|
}
|
|
|
|
|
2021-03-05 17:31:12 +00:00
|
|
|
func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
|
|
|
models := []interface{}{
|
2021-04-01 18:46:45 +00:00
|
|
|
(*model.Account)(nil),
|
|
|
|
(*model.Status)(nil),
|
|
|
|
(*model.User)(nil),
|
2021-03-05 17:31:12 +00:00
|
|
|
}
|
|
|
|
ps.log.Info("creating db schema")
|
|
|
|
|
|
|
|
for _, model := range models {
|
|
|
|
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
|
|
|
|
IfNotExists: true,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ps.log.Info("db schema created")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2021-03-22 21:26:54 +00:00
|
|
|
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
|
|
|
|
}
|
|
|
|
return nil
|
2021-03-22 21:26:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
2021-03-22 21:26:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetAll(i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if err := ps.conn.Model(i).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
2021-03-22 21:26:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) Put(i interface{}) error {
|
|
|
|
_, err := ps.conn.Model(i).Insert(i)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
|
|
|
|
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
|
2021-03-22 21:26:54 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
2021-03-22 21:26:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
|
2021-04-01 18:46:45 +00:00
|
|
|
if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
HANDY SHORTCUTS
|
|
|
|
*/
|
|
|
|
|
|
|
|
func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
|
|
|
|
user := &model.User{
|
|
|
|
ID: userID,
|
|
|
|
}
|
|
|
|
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
|
|
|
|
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
|
|
|
|
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
|
|
|
|
if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
|
|
|
|
if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
|
|
|
|
q := ps.conn.Model(statuses).Order("created_at DESC")
|
|
|
|
if limit != 0 {
|
|
|
|
q = q.Limit(limit)
|
|
|
|
}
|
|
|
|
if accountID != "" {
|
|
|
|
q = q.Where("account_id = ?", accountID)
|
|
|
|
}
|
|
|
|
if err := q.Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
|
|
|
|
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) IsUsernameAvailable(username string) error {
|
|
|
|
// if no error we fail because it means we found something
|
|
|
|
// if error but it's not pg.ErrNoRows then we fail
|
|
|
|
// if err is pg.ErrNoRows we're good, we found nothing so continue
|
|
|
|
if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
|
|
|
|
return fmt.Errorf("username %s already in use", username)
|
|
|
|
} else if err != pg.ErrNoRows {
|
|
|
|
return fmt.Errorf("db error: %s", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) IsEmailAvailable(email string) error {
|
|
|
|
// parse the domain from the email
|
|
|
|
m, err := mail.ParseAddress(email)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error parsing email address %s: %s", email, err)
|
|
|
|
}
|
|
|
|
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
|
|
|
|
|
|
|
// check if the email domain is blocked
|
|
|
|
if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
|
|
|
|
// fail because we found something
|
|
|
|
return fmt.Errorf("email domain %s is blocked", domain)
|
|
|
|
} else if err != pg.ErrNoRows {
|
|
|
|
// fail because we got an unexpected error
|
|
|
|
return fmt.Errorf("db error: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if this email is associated with a user already
|
|
|
|
if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
|
|
|
// fail because we found something
|
|
|
|
return fmt.Errorf("email %s already in use", email)
|
|
|
|
} else if err != pg.ErrNoRows {
|
|
|
|
// fail because we got an unexpected error
|
|
|
|
return fmt.Errorf("db error: %s", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
|
|
|
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
if err != nil {
|
|
|
|
ps.log.Errorf("error creating new rsa key: %s", err)
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host)
|
|
|
|
|
|
|
|
a := &model.Account{
|
|
|
|
Username: username,
|
|
|
|
DisplayName: username,
|
|
|
|
Reason: reason,
|
|
|
|
URL: uris.UserURL,
|
|
|
|
PrivateKey: key,
|
|
|
|
PublicKey: &key.PublicKey,
|
|
|
|
ActorType: "Person",
|
|
|
|
URI: uris.UserURI,
|
|
|
|
InboxURL: uris.InboxURL,
|
|
|
|
OutboxURL: uris.OutboxURL,
|
|
|
|
FollowersURL: uris.FollowersURL,
|
|
|
|
FeaturedCollectionURL: uris.CollectionURL,
|
|
|
|
}
|
|
|
|
if _, err = ps.conn.Model(a).Insert(); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error hashing password: %s", err)
|
|
|
|
}
|
|
|
|
u := &model.User{
|
|
|
|
AccountID: a.ID,
|
|
|
|
EncryptedPassword: string(pw),
|
|
|
|
SignUpIP: signUpIP,
|
|
|
|
Locale: locale,
|
|
|
|
UnconfirmedEmail: email,
|
|
|
|
CreatedByApplicationID: appID,
|
|
|
|
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
|
|
|
|
}
|
|
|
|
if _, err = ps.conn.Model(u).Insert(); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return u, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
|
|
|
|
_, err := ps.conn.Model(mediaAttachment).Insert()
|
2021-03-22 21:26:54 +00:00
|
|
|
return err
|
2021-03-14 16:56:16 +00:00
|
|
|
}
|
2021-04-01 18:46:45 +00:00
|
|
|
|
|
|
|
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
|
|
|
|
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
|
|
|
|
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
|
|
|
|
if err == pg.ErrNoRows {
|
|
|
|
return ErrNoEntries{}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
CONVERSION FUNCTIONS
|
|
|
|
*/
|
|
|
|
|
|
|
|
// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
|
|
|
|
// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
|
|
|
|
// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
|
|
|
|
// that the account actually belongs to.
|
|
|
|
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
|
|
|
|
// we can build this sensitive account easily by first getting the public account....
|
|
|
|
mastoAccount, err := ps.AccountToMastoPublic(a)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// then adding the Source object to it...
|
|
|
|
|
|
|
|
// check pending follow requests aimed at this account
|
|
|
|
fr := []model.FollowRequest{}
|
|
|
|
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting follow requests: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var frc int
|
|
|
|
if fr != nil {
|
|
|
|
frc = len(fr)
|
|
|
|
}
|
|
|
|
|
|
|
|
mastoAccount.Source = &mastotypes.Source{
|
|
|
|
Privacy: a.Privacy,
|
|
|
|
Sensitive: a.Sensitive,
|
|
|
|
Language: a.Language,
|
|
|
|
Note: a.Note,
|
|
|
|
Fields: mastoAccount.Fields,
|
|
|
|
FollowRequestsCount: frc,
|
|
|
|
}
|
|
|
|
|
|
|
|
return mastoAccount, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) {
|
|
|
|
// count followers
|
|
|
|
followers := []model.Follow{}
|
|
|
|
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting followers: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var followersCount int
|
|
|
|
if followers != nil {
|
|
|
|
followersCount = len(followers)
|
|
|
|
}
|
|
|
|
|
|
|
|
// count following
|
|
|
|
following := []model.Follow{}
|
|
|
|
if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting following: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var followingCount int
|
|
|
|
if following != nil {
|
|
|
|
followingCount = len(following)
|
|
|
|
}
|
|
|
|
|
|
|
|
// count statuses
|
|
|
|
statuses := []model.Status{}
|
|
|
|
if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting last statuses: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var statusesCount int
|
|
|
|
if statuses != nil {
|
|
|
|
statusesCount = len(statuses)
|
|
|
|
}
|
|
|
|
|
|
|
|
// check when the last status was
|
|
|
|
lastStatus := &model.Status{}
|
|
|
|
if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting last status: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
var lastStatusAt string
|
|
|
|
if lastStatus != nil {
|
|
|
|
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
|
|
|
|
}
|
|
|
|
|
|
|
|
// build the avatar and header URLs
|
|
|
|
avi := &model.MediaAttachment{}
|
|
|
|
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting avatar: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
aviURL := avi.File.Path
|
|
|
|
aviURLStatic := avi.Thumbnail.Path
|
|
|
|
|
|
|
|
header := &model.MediaAttachment{}
|
|
|
|
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
|
|
|
|
if _, ok := err.(ErrNoEntries); !ok {
|
|
|
|
return nil, fmt.Errorf("error getting header: %s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
headerURL := header.File.Path
|
|
|
|
headerURLStatic := header.Thumbnail.Path
|
|
|
|
|
|
|
|
// get the fields set on this account
|
|
|
|
fields := []mastotypes.Field{}
|
|
|
|
for _, f := range a.Fields {
|
|
|
|
mField := mastotypes.Field{
|
|
|
|
Name: f.Name,
|
|
|
|
Value: f.Value,
|
|
|
|
}
|
|
|
|
if !f.VerifiedAt.IsZero() {
|
|
|
|
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
|
|
|
|
}
|
|
|
|
fields = append(fields, mField)
|
|
|
|
}
|
|
|
|
|
|
|
|
var acct string
|
|
|
|
if a.Domain != "" {
|
|
|
|
// this is a remote user
|
|
|
|
acct = fmt.Sprintf("%s@%s", a.Username, a.Domain)
|
|
|
|
} else {
|
|
|
|
// this is a local user
|
|
|
|
acct = a.Username
|
|
|
|
}
|
|
|
|
|
|
|
|
return &mastotypes.Account{
|
|
|
|
ID: a.ID,
|
|
|
|
Username: a.Username,
|
|
|
|
Acct: acct,
|
|
|
|
DisplayName: a.DisplayName,
|
|
|
|
Locked: a.Locked,
|
|
|
|
Bot: a.Bot,
|
|
|
|
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
|
|
|
Note: a.Note,
|
|
|
|
URL: a.URL,
|
|
|
|
Avatar: aviURL,
|
|
|
|
AvatarStatic: aviURLStatic,
|
|
|
|
Header: headerURL,
|
|
|
|
HeaderStatic: headerURLStatic,
|
|
|
|
FollowersCount: followersCount,
|
|
|
|
FollowingCount: followingCount,
|
|
|
|
StatusesCount: statusesCount,
|
|
|
|
LastStatusAt: lastStatusAt,
|
|
|
|
Emojis: nil, // TODO: implement this
|
|
|
|
Fields: fields,
|
|
|
|
}, nil
|
|
|
|
}
|