account update nearly working

This commit is contained in:
tsmethurst 2021-03-30 22:26:52 +02:00
parent 362ccf5817
commit c8ff849a02
5 changed files with 337 additions and 76 deletions

View file

@ -92,6 +92,9 @@ type DB interface {
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
@ -156,7 +159,15 @@ type DB interface {
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error
SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error
// GetHeaderForAccountID gets the current header for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error
/*
USEFUL CONVERSION FUNCTIONS

View file

@ -274,6 +274,11 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error {
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
if err == pg.ErrNoRows {
@ -468,6 +473,26 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.
return err
}
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
/*
CONVERSION FUNCTIONS
*/
@ -478,18 +503,6 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.
// that the account actually belongs to.
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
fields := []mastotypes.Field{}
for _, f := range a.Fields {
mField := mastotypes.Field{
Name: f.Name,
Value: f.Value,
}
if !f.VerifiedAt.IsZero() {
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
}
fields = append(fields, mField)
}
// count followers
followers := []model.Follow{}
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
@ -538,6 +551,39 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
}
// build the avatar and header URLs
avi := &model.MediaAttachment{}
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting avatar: %s", err)
}
}
aviURL := avi.File.Path
aviURLStatic := avi.Thumbnail.Path
header := &model.MediaAttachment{}
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting header: %s", err)
}
}
headerURL := header.File.Path
headerURLStatic := header.Thumbnail.Path
// get the fields set on this account
fields := []mastotypes.Field{}
for _, f := range a.Fields {
mField := mastotypes.Field{
Name: f.Name,
Value: f.Value,
}
if !f.VerifiedAt.IsZero() {
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
}
fields = append(fields, mField)
}
// check pending follow requests aimed at this account
fr := []model.FollowRequest{}
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
@ -549,6 +595,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
frc = len(fr)
}
// derive source from fields and other info
source := &mastotypes.Source{
Privacy: a.Privacy,
Sensitive: a.Sensitive,
@ -567,17 +614,17 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
Bot: a.Bot,
CreatedAt: a.CreatedAt.Format(time.RFC3339),
Note: a.Note,
URL: a.URL,
Avatar: a.AvatarRemoteURL.String(),
AvatarStatic: a.AvatarRemoteURL.String(),
Header: a.HeaderRemoteURL.String(),
HeaderStatic: a.HeaderRemoteURL.String(),
URL: a.URL, // TODO: set this during account creation
Avatar: aviURL, // TODO: build this url properly using host and protocol from config
AvatarStatic: aviURLStatic, // TODO: build this url properly using host and protocol from config
Header: headerURL, // TODO: build this url properly using host and protocol from config
HeaderStatic: headerURLStatic, // TODO: build this url properly using host and protocol from config
FollowersCount: followersCount,
FollowingCount: followingCount,
StatusesCount: statusesCount,
LastStatusAt: lastStatusAt,
Source: source,
Emojis: nil,
Emojis: nil, // TODO: implement this
Fields: fields,
}, nil
}

View file

@ -19,8 +19,11 @@
package account
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"net/http"
@ -39,8 +42,9 @@ import (
)
const (
idKey = "id"
basePath = "/api/v1/accounts"
basePathWithID = basePath + "/:id"
basePathWithID = basePath + "/:" + idKey
verifyPath = basePath + "/verify_credentials"
updateCredentialsPath = basePath + "/update_credentials"
)
@ -144,6 +148,10 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings.
// It should be served as a PATCH at /api/v1/accounts/update_credentials
//
// TODO: this can be optimized massively by building up a picture of what we want the new account
// details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one
// which is not gonna make the database very happy when lots of requests are going through.
func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler")
authed, err := oauth.MustAuth(c, true, false, false, true)
@ -152,63 +160,180 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
l.Tracef("retrieved account %+v", authed.Account.ID)
l.Trace("parsing request form")
form := &mastotypes.UpdateCredentialsRequest{}
if err := c.ShouldBind(form); err != nil || form == nil {
l.Debugf("could not parse form from request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"})
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// TODO: proper form validation
// TODO: tidy this code into subfunctions
if form.Header != nil && form.Header.Size != 0 {
if form.Header.Size > m.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize)
l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
f, err := form.Header.Open()
if err != nil {
l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
return
}
// extract the bytes
imageBytes := []byte{}
size, err := f.Read(imageBytes)
defer func(){
if err := f.Close(); err != nil {
m.log.Errorf("error closing multipart file: %s", err)
}
}()
if err != nil || size == 0 {
l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
return
}
// do the setting
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header")
if err != nil {
l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("new header info for account %s is %+v", headerInfo)
// if everything on the form is nil, then nothing has been set and we shouldn't continue
if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil {
l.Debugf("could not parse form from request")
c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"})
return
}
l.Tracef("retrieved account %+v", authed.Account.ID)
if form.Discoverable != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil {
l.Debugf("error updating discoverable: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Bot != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil {
l.Debugf("error updating bot: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.DisplayName != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Note != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil {
l.Debugf("error updating note: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Avatar != nil && form.Avatar.Size != 0 {
avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID)
if err != nil {
l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo)
}
if form.Header != nil && form.Header.Size != 0 {
headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID)
if err != nil {
l.Debugf("could not update header for account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo)
}
if form.Locked != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()})
return
}
}
if form.Source != nil {
}
if form.FieldsAttributes != nil {
}
// fetch the account with all updated values set
updatedAccount := &model.Account{}
if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil {
l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount)
if err != nil {
l.Tracef("could not convert account into mastosensitive account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
c.JSON(http.StatusOK, acctSensitive)
}
/*
HELPER FUNCTIONS
*/
// TODO: try to combine the below two functions because this is a lot of code repetition.
// UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
var err error
if avatar.Size > m.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize)
return nil, err
}
f, err := avatar.Open()
if err != nil {
return nil, fmt.Errorf("could not read provided avatar: %s", err)
}
// extract the bytes
buf := new(bytes.Buffer)
size, err := io.Copy(buf, f)
if err != nil {
return nil, fmt.Errorf("could not read provided avatar: %s", err)
}
if size == 0 {
return nil, errors.New("could not read provided avatar: size 0 bytes")
}
// do the setting
avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar")
if err != nil {
return nil, fmt.Errorf("error processing avatar: %s", err)
}
return avatarInfo, f.Close()
}
// UpdateAccountHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
var err error
if header.Size > m.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize)
return nil, err
}
f, err := header.Open()
if err != nil {
return nil, fmt.Errorf("could not read provided header: %s", err)
}
// extract the bytes
buf := new(bytes.Buffer)
size, err := io.Copy(buf, f)
if err != nil {
return nil, fmt.Errorf("could not read provided header: %s", err)
}
if size == 0 {
return nil, errors.New("could not read provided header: size 0 bytes")
}
// do the setting
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header")
if err != nil {
return nil, fmt.Errorf("error processing header: %s", err)
}
return headerInfo, f.Close()
}
// accountCreate does the dirty work of making an account and user in the database.
// It then returns a token to the caller, for use with the new account, as per the
// spec here: https://docs.joinmastodon.org/methods/accounts/

View file

@ -19,13 +19,17 @@
package account
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
@ -40,6 +44,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/models"
@ -57,7 +62,8 @@ type AccountTestSuite struct {
testApplication *model.Application
testToken oauth2.TokenInfo
mockOauthServer *oauth.MockServer
mockMediaHandler *media.MockMediaHandler
mockStorage *storage.MockStorage
mediaHandler media.MediaHandler
db db.DB
accountModule *accountModule
newUserFormHappyPath url.Values
@ -74,6 +80,11 @@ func (suite *AccountTestSuite) SetupSuite() {
log.SetLevel(logrus.TraceLevel)
suite.log = log
suite.testAccountLocal = &model.Account{
ID: uuid.NewString(),
Username: "test_user",
}
// can use this test application throughout
suite.testApplication = &model.Application{
ID: "weeweeeeeeeeeeeeee",
@ -107,6 +118,9 @@ func (suite *AccountTestSuite) SetupSuite() {
Database: "postgres",
ApplicationName: "gotosocial",
}
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
suite.config = c
// use an actual database for this, because it's just easier than mocking one out
@ -130,11 +144,15 @@ func (suite *AccountTestSuite) SetupSuite() {
Code: "we're authorized now!",
}, nil)
// mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
suite.mockMediaHandler = &media.MockMediaHandler{}
suite.mockStorage = &storage.MockStorage{}
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
// and finally here's the thing we're actually testing!
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule)
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
}
func (suite *AccountTestSuite) TearDownSuite() {
@ -150,9 +168,11 @@ func (suite *AccountTestSuite) SetupTest() {
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.CreateTable(m); err != nil {
@ -186,9 +206,11 @@ func (suite *AccountTestSuite) TearDownTest() {
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
@ -201,6 +223,10 @@ func (suite *AccountTestSuite) TearDownTest() {
ACTUAL TESTS
*/
/*
TESTING: AccountCreatePOSTHandler
*/
// TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid,
// and at the end of it a new user and account should be added into the database.
//
@ -455,6 +481,58 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason()
assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b))
}
/*
TESTING: AccountUpdateCredentialsPATCHHandler
*/
func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
// put test local account in db
err := suite.db.Put(suite.testAccountLocal)
assert.NoError(suite.T(), err)
// attach avatar to request
aviFile, err := os.Open("../../media/test/test-jpeg.jpg")
assert.NoError(suite.T(), err)
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
assert.NoError(suite.T(), err)
_, err = io.Copy(part, aviFile)
assert.NoError(suite.T(), err)
err = aviFile.Close()
assert.NoError(suite.T(), err)
err = writer.Close()
assert.NoError(suite.T(), err)
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
// check response
// 1. we should have OK because our request was valid
suite.EqualValues(http.StatusOK, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
// TODO: implement proper checks here
//
// b, err := ioutil.ReadAll(result.Body)
// assert.NoError(suite.T(), err)
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

View file

@ -92,40 +92,40 @@ type AccountCreateRequest struct {
// See https://docs.joinmastodon.org/methods/accounts/
type UpdateCredentialsRequest struct {
// Whether the account should be shown in the profile directory.
Discoverable string `form:"discoverable"`
Discoverable *bool `form:"discoverable"`
// Whether the account has a bot flag.
Bot bool `form:"bot"`
Bot *bool `form:"bot"`
// The display name to use for the profile.
DisplayName string `form:"display_name"`
DisplayName *string `form:"display_name"`
// The account bio.
Note string `form:"note"`
Note *string `form:"note"`
// Avatar image encoded using multipart/form-data
Avatar *multipart.FileHeader `form:"avatar"`
// Header image encoded using multipart/form-data
Header *multipart.FileHeader `form:"header"`
// Whether manual approval of follow requests is required.
Locked bool `form:"locked"`
Locked *bool `form:"locked"`
// New Source values for this account
Source *UpdateSource `form:"source"`
// Profile metadata name and value
FieldsAttributes []UpdateField `form:"fields_attributes"`
FieldsAttributes *[]UpdateField `form:"fields_attributes"`
}
// UpdateSource is to be used specifically in an UpdateCredentialsRequest.
type UpdateSource struct {
// Default post privacy for authored statuses.
Privacy string `form:"privacy"`
Privacy *string `form:"privacy"`
// Whether to mark authored statuses as sensitive by default.
Sensitive bool `form:"sensitive"`
Sensitive *bool `form:"sensitive"`
// Default language to use for authored statuses. (ISO 6391)
Language string `form:"language"`
Language *string `form:"language"`
}
// UpdateField is to be used specifically in an UpdateCredentialsRequest.
// By default, max 4 fields and 255 characters per property/value.
type UpdateField struct {
// Name of the field
Name string `form:"name"`
Name *string `form:"name"`
// Value of the field
Value string `form:"value"`
Value *string `form:"value"`
}