Api/v1/accounts (#8)

* start work on accounts module

* plodding away on the accounts endpoint

* groundwork for other account routes

* add password validator

* validation utils

* require account approval flags

* comments

* comments

* go fmt

* comments

* add distributor stub

* rename api to federator

* tidy a bit

* validate new account requests

* rename r router

* comments

* add domain blocks

* add some more shortcuts

* add some more shortcuts

* check email + username availability

* email block checking for signups

* chunking away at it

* tick off a few more things

* some fiddling with tests

* add mock package

* relocate repo

* move mocks around

* set app id on new signups

* initialize oauth server properly

* rename oauth server

* proper mocking tests

* go fmt ./...

* add required fields

* change name of func

* move validation to account.go

* more tests!

* add some file utility tools

* add mediaconfig

* new shortcut

* add some more fields

* add followrequest model

* add notify

* update mastotypes

* mock out storage interface

* start building media interface

* start on update credentials

* mess about with media a bit more

* test image manipulation

* media more or less working

* account update nearly working

* rearranging my package ;) ;) ;)

* phew big stuff!!!!

* fix type checking

* *fiddles*

* Add CreateTables func

* account registration flow working

* tidy

* script to step through auth flow

* add a lil helper for generating user uris

* fiddling with federation a bit

* update progress

* Tidying and linting
This commit is contained in:
Tobi Smethurst 2021-04-01 20:46:45 +02:00 committed by GitHub
parent aa9ce272dc
commit 71a49e2b43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
94 changed files with 6585 additions and 955 deletions

View file

@ -11,10 +11,10 @@
* [x] /auth/sign_in GET (Show form for user signin)
* [x] /auth/sign_in POST (Validate username and password and sign user in)
* [ ] Accounts
* [ ] /api/v1/accounts POST (Register a new account)
* [ ] /api/v1/accounts/verify_credentials GET (Verify account credentials with a user token)
* [ ] /api/v1/accounts/update_credentials PATCH (Update user's display name/preferences)
* [ ] /api/v1/accounts/:id GET (Get account information)
* [x] /api/v1/accounts POST (Register a new account)
* [x] /api/v1/accounts/verify_credentials GET (Verify account credentials with a user token)
* [x] /api/v1/accounts/update_credentials PATCH (Update user's display name/preferences)
* [x] /api/v1/accounts/:id GET (Get account information)
* [ ] /api/v1/accounts/:id/statuses GET (Get an account's statuses)
* [ ] /api/v1/accounts/:id/followers GET (Get an account's followers)
* [ ] /api/v1/accounts/:id/following GET (Get an account's following)
@ -184,7 +184,7 @@
* [ ] Cache
* [ ] In-memory cache
* [ ] Security features
* [ ] Authorization middleware
* [x] Authorization middleware
* [ ] Rate limiting middleware
* [ ] Scope middleware
* [ ] Permissions/acl middleware for admins+moderators

View file

@ -22,12 +22,12 @@ import (
"fmt"
"os"
"github.com/gotosocial/gotosocial/internal/action"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/gotosocial"
"github.com/gotosocial/gotosocial/internal/log"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/action"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/urfave/cli/v2"
)
@ -111,10 +111,70 @@ func main() {
// TEMPLATE FLAGS
&cli.StringFlag{
Name: flagNames.TemplateBaseDir,
Usage: "Basedir for html templating files for rendering pages and composing emails",
Usage: "Basedir for html templating files for rendering pages and composing emails.",
Value: "./web/template/",
EnvVars: []string{envNames.TemplateBaseDir},
},
// ACCOUNTS FLAGS
&cli.BoolFlag{
Name: flagNames.AccountsOpenRegistration,
Usage: "Allow anyone to submit an account signup request. If false, server will be invite-only.",
Value: true,
EnvVars: []string{envNames.AccountsOpenRegistration},
},
&cli.BoolFlag{
Name: flagNames.AccountsRequireApproval,
Usage: "Do account signups require approval by an admin or moderator before user can log in? If false, new registrations will be automatically approved.",
Value: true,
EnvVars: []string{envNames.AccountsRequireApproval},
},
// MEDIA FLAGS
&cli.IntFlag{
Name: flagNames.MediaMaxImageSize,
Usage: "Max size of accepted images in bytes",
Value: 1048576, // 1mb
EnvVars: []string{envNames.MediaMaxImageSize},
},
&cli.IntFlag{
Name: flagNames.MediaMaxVideoSize,
Usage: "Max size of accepted videos in bytes",
Value: 5242880, // 5mb
EnvVars: []string{envNames.MediaMaxVideoSize},
},
// STORAGE FLAGS
&cli.StringFlag{
Name: flagNames.StorageBackend,
Usage: "Storage backend to use for media attachments",
Value: "local",
EnvVars: []string{envNames.StorageBackend},
},
&cli.StringFlag{
Name: flagNames.StorageBasePath,
Usage: "Full path to an already-created directory where gts should store/retrieve media files",
Value: "/opt/gotosocial",
EnvVars: []string{envNames.StorageBasePath},
},
&cli.StringFlag{
Name: flagNames.StorageServeProtocol,
Usage: "Protocol to use for serving media attachments (use https if storage is local)",
Value: "https",
EnvVars: []string{envNames.StorageServeProtocol},
},
&cli.StringFlag{
Name: flagNames.StorageServeHost,
Usage: "Hostname to serve media attachments from (use the same value as host if storage is local)",
Value: "localhost",
EnvVars: []string{envNames.StorageServeHost},
},
&cli.StringFlag{
Name: flagNames.StorageServeBasePath,
Usage: "Path to append to protocol and hostname to create the base path from which media files will be served (default will mostly be fine)",
Value: "/fileserver/media",
EnvVars: []string{envNames.StorageServeBasePath},
},
},
Commands: []*cli.Command{
{

View file

@ -14,10 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
###################
##### CONFIG ######
###################
###########################
##### GENERAL CONFIG ######
###########################
# String. Log level to use throughout the application. Must be lower-case.
# Options: ["debug","info","warn","error","fatal"]
# Default: "info"
@ -39,6 +38,9 @@ host: "localhost"
# Default: "https"
protocol: "https"
############################
##### DATABASE CONFIG ######
############################
# Config pertaining to the Gotosocial database connection
db:
# String. Database type.
@ -72,9 +74,26 @@ db:
# Default: "postgres"
database: "postgres"
###############################
##### WEB TEMPLATE CONFIG #####
###############################
# Config pertaining to templating of web pages/email notifications and the like
template:
# String. Directory from which gotosocial will attempt to load html templates (.tmpl files).
# Examples: ["/some/absolute/path/", "./relative/path/", "../../some/weird/path/"]
# Default: "./web/template/"
baseDir: "./web/template/"
###########################
##### ACCOUNTS CONFIG #####
###########################
# Config pertaining to creation and maintenance of accounts on the server, as well as defaults for new accounts.
accounts:
# Bool. Do we want people to be able to just submit sign up requests, or do we want invite only?
# Options: [true, false]
# Default: true
openRegistration: true
# Bool. Do sign up requests require approval from an admin/moderator before an account can sign in/use the server?
# Options: [true, false]
# Default: true
requireApproval: true

13
go.mod
View file

@ -1,8 +1,10 @@
module github.com/gotosocial/gotosocial
module github.com/superseriousbusiness/gotosocial
go 1.16
require (
github.com/buckket/go-blurhash v1.1.0
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
github.com/gin-contrib/sessions v0.0.3
github.com/gin-gonic/gin v1.6.3
github.com/go-fed/activity v1.0.0
@ -10,16 +12,23 @@ require (
github.com/go-pg/pg/v10 v10.8.0
github.com/golang/mock v1.4.4 // indirect
github.com/google/uuid v1.2.0
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88
github.com/h2non/filetype v1.1.1
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/onsi/ginkgo v1.15.0 // indirect
github.com/onsi/gomega v1.10.5 // indirect
github.com/sirupsen/logrus v1.8.0
github.com/stretchr/testify v1.7.0
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f
github.com/tidwall/btree v0.4.2 // indirect
github.com/tidwall/buntdb v1.2.0 // indirect
github.com/tidwall/pretty v1.1.0 // indirect
github.com/urfave/cli/v2 v2.3.0
github.com/wagslane/go-password-validator v0.3.0
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b
golang.org/x/text v0.3.3
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v2 v2.3.0
)

57
go.sum
View file

@ -7,16 +7,37 @@ github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
github.com/bradfitz/gomemcache v0.0.0-20190329173943-551aad21a668/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20181103040241-659414f458e1/go.mod h1:dkChI7Tbtx7H1Tj7TqGSZMOeGpMP5gLHtjroHd4agiI=
github.com/buckket/go-blurhash v1.1.0 h1:X5M6r0LIvwdvKiUtiNcRL2YlmOfMzYobI3VCKCZc9Do=
github.com/buckket/go-blurhash v1.1.0/go.mod h1:aT2iqo5W9vu9GpyoLErKfTHwgODsZp3bQfXjXJUxNb8=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/dave/jennifer v1.3.0/go.mod h1:fIb+770HOpJ2fmN9EPPKOqm1vMGhB+TwXKMZhrIygKg=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dsoprea/go-exif v0.0.0-20210131231135-d154f10435cc h1:AuzYp98IFVOi0NU/WcZyGDQ6vAh/zkCjxGD3kt8aLzA=
github.com/dsoprea/go-exif v0.0.0-20210131231135-d154f10435cc/go.mod h1:lOaOt7+UEppOgyvRy749v3do836U/hw0YVJNjoyPaEs=
github.com/dsoprea/go-exif/v2 v2.0.0-20200321225314-640175a69fe4/go.mod h1:Lm2lMM2zx8p4a34ZemkaUV95AnMl4ZvLbCUbwOvLC2E=
github.com/dsoprea/go-exif/v2 v2.0.0-20200604193436-ca8584a0e1c4 h1:Mg7pY7kxDQD2Bkvr1N+XW4BESSIQ7tTTR7Vv+Gi2CsM=
github.com/dsoprea/go-exif/v2 v2.0.0-20200604193436-ca8584a0e1c4/go.mod h1:9EXlPeHfblFFnwu5UOqmP2eoZfJyAZ2Ri/Vki33ajO0=
github.com/dsoprea/go-iptc v0.0.0-20200609062250-162ae6b44feb h1:gwjJjUr6FY7zAWVEueFPrcRHhd9+IK81TcItbqw2du4=
github.com/dsoprea/go-iptc v0.0.0-20200609062250-162ae6b44feb/go.mod h1:kYIdx9N9NaOyD7U6D+YtExN7QhRm+5kq7//yOsRXQtM=
github.com/dsoprea/go-jpeg-image-structure v0.0.0-20210128210355-86b1014917f2 h1:ULCSN6v0WISNbALxomGPXh4dSjRKPW+7+seYoMz8UTc=
github.com/dsoprea/go-jpeg-image-structure v0.0.0-20210128210355-86b1014917f2/go.mod h1:ZoOP3yUG0HD1T4IUjIFsz/2OAB2yB4YX6NSm4K+uJRg=
github.com/dsoprea/go-logging v0.0.0-20190624164917-c4f10aab7696/go.mod h1:Nm/x2ZUNRW6Fe5C3LxdY1PyZY5wmDv/s5dkPJ/VB3iA=
github.com/dsoprea/go-logging v0.0.0-20200517223158-a10564966e9d h1:F/7L5wr/fP/SKeO5HuMlNEX9Ipyx2MbH2rV9G4zJRpk=
github.com/dsoprea/go-logging v0.0.0-20200517223158-a10564966e9d/go.mod h1:7I+3Pe2o/YSU88W0hWlm9S22W7XI1JFNJ86U0zPKMf8=
github.com/dsoprea/go-photoshop-info-format v0.0.0-20200609050348-3db9b63b202c h1:7j5aWACOzROpr+dvMtu8GnI97g9ShLWD72XIELMgn+c=
github.com/dsoprea/go-photoshop-info-format v0.0.0-20200609050348-3db9b63b202c/go.mod h1:pqKB+ijp27cEcrHxhXVgUUMlSDRuGJJp1E+20Lj5H0E=
github.com/dsoprea/go-png-image-structure v0.0.0-20200807080309-a98d4e94ac82 h1:RdwKOEEe2ND/JmoKh6I/EQlR9idKJTDOMffPFK6vN2M=
github.com/dsoprea/go-png-image-structure v0.0.0-20200807080309-a98d4e94ac82/go.mod h1:aDYQkL/5gfRNZkoxiLTSWU4Y8/gV/4MVsy/MU9uwTak=
github.com/dsoprea/go-utility v0.0.0-20200512094054-1abbbc781176 h1:CfXezFYb2STGOd1+n1HshvE191zVx+QX3A1nML5xxME=
github.com/dsoprea/go-utility v0.0.0-20200512094054-1abbbc781176/go.mod h1:95+K3z2L0mqsVYd6yveIv1lmtT3tcQQ3dVakPySffW8=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8=
@ -35,6 +56,9 @@ github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmC
github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-errors/errors v1.0.2 h1:xMxH9j2fNg/L4hLn/4y3M0IUsn0M6Wbu/Uh9QlOfBh4=
github.com/go-errors/errors v1.0.2/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
github.com/go-fed/activity v1.0.0 h1:j7w3auHZnVCjUcgA1mE+UqSOjFBhvW2Z2res3vNol+o=
github.com/go-fed/activity v1.0.0/go.mod h1:v4QoPaAzjWZ8zN2VFVGL5ep9C02mst0hQYHUpQwso4Q=
github.com/go-fed/httpsig v0.1.1-0.20190914113940-c2de3672e5b5 h1:WLvFZqoXnuVTBKA6U/1FnEHNQ0Rq0QM0rGhY8Tx6R1g=
@ -58,6 +82,11 @@ github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg=
github.com/go-test/deep v1.0.1/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b h1:khEcpUM4yFcxg4/FHQWkvVRmgijNXRfzkIDHh23ggEo=
github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM=
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/geo v0.0.0-20200319012246-673a6f80352d h1:C/hKUcHT483btRbeGkrRjJz+Zbcj8audldIi9tRJDCc=
github.com/golang/geo v0.0.0-20200319012246-673a6f80352d/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -103,11 +132,12 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU=
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
github.com/h2non/filetype v1.1.1 h1:xvOwnXKAckvtLWsN398qS9QhlxlnVXBjXBydK2/UFB4=
github.com/h2non/filetype v1.1.1/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@ -136,12 +166,16 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@ -174,6 +208,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@ -182,6 +217,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203 h1:1SWXcTphBQjYGWRRxLFIAR1LVtQEj4eR7xPtyeOVM/c=
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203/go.mod h1:0Xw5cYMOYpgaWs+OOSx41ugycl2qvKTi9tlMMcZhFyY=
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f h1:0YcjA/ieDuDFHJPg5w2hk3r5kIWNvEyl7GsoArxdI3s=
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f/go.mod h1:8p0a/BEN9hhsGzE3tPaFFlIZgxAaLyLN5KY0bPg9ZBc=
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
github.com/tidwall/btree v0.3.0/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
github.com/tidwall/btree v0.4.2 h1:aLwwJlG+InuFzdAPuBf9YCAR1LvSQ9zhC5aorFPlIPs=
@ -235,6 +274,8 @@ github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vb
github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
@ -280,9 +321,14 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200320220750-118fecf932d8/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
@ -371,6 +417,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -21,8 +21,8 @@ package action
import (
"context"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
// GTSAction defines one *action* that can be taken by the gotosocial cli command.

View file

@ -0,0 +1,32 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package action
import (
context "context"
config "github.com/superseriousbusiness/gotosocial/internal/config"
logrus "github.com/sirupsen/logrus"
mock "github.com/stretchr/testify/mock"
)
// MockGTSAction is an autogenerated mock type for the GTSAction type
type MockGTSAction struct {
mock.Mock
}
// Execute provides a mock function with given fields: _a0, _a1, _a2
func (_m *MockGTSAction) Execute(_a0 context.Context, _a1 *config.Config, _a2 *logrus.Logger) error {
ret := _m.Called(_a0, _a1, _a2)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *config.Config, *logrus.Logger) error); ok {
r0 = rf(_a0, _a1, _a2)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -0,0 +1,100 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
const (
idKey = "id"
basePath = "/api/v1/accounts"
basePathWithID = basePath + "/:" + idKey
verifyPath = basePath + "/verify_credentials"
updateCredentialsPath = basePath + "/update_credentials"
)
type accountModule struct {
config *config.Config
db db.DB
oauthServer oauth.Server
mediaHandler media.MediaHandler
log *logrus.Logger
}
// New returns a new account module
func New(config *config.Config, db db.DB, oauthServer oauth.Server, mediaHandler media.MediaHandler, log *logrus.Logger) apimodule.ClientAPIModule {
return &accountModule{
config: config,
db: db,
oauthServer: oauthServer,
mediaHandler: mediaHandler,
log: log,
}
}
// Route attaches all routes from this module to the given router
func (m *accountModule) Route(r router.Router) error {
r.AttachHandler(http.MethodPost, basePath, m.accountCreatePOSTHandler)
r.AttachHandler(http.MethodGet, basePathWithID, m.muxHandler)
return nil
}
func (m *accountModule) CreateTables(db db.DB) error {
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}
func (m *accountModule) muxHandler(c *gin.Context) {
ru := c.Request.RequestURI
if strings.HasPrefix(ru, verifyPath) {
m.accountVerifyGETHandler(c)
} else if strings.HasPrefix(ru, updateCredentialsPath) {
m.accountUpdateCredentialsPATCHHandler(c)
} else {
m.accountGETHandler(c)
}
}

View file

@ -0,0 +1,155 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"errors"
"fmt"
"net"
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"github.com/superseriousbusiness/oauth2/v4"
)
// accountCreatePOSTHandler handles create account requests, validates them,
// and puts them in the database if they're valid.
// It should be served as a POST at /api/v1/accounts
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "accountCreatePOSTHandler")
authed, err := oauth.MustAuth(c, true, true, false, false)
if err != nil {
l.Debugf("couldn't auth: %s", err)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
l.Trace("parsing request form")
form := &mastotypes.AccountCreateRequest{}
if err := c.ShouldBind(form); err != nil || form == nil {
l.Debugf("could not parse form from request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"})
return
}
l.Tracef("validating form %+v", form)
if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
l.Debugf("error validating form: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
clientIP := c.ClientIP()
l.Tracef("attempting to parse client ip address %s", clientIP)
signUpIP := net.ParseIP(clientIP)
if signUpIP == nil {
l.Debugf("error validating sign up ip address %s", clientIP)
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
return
}
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
if err != nil {
l.Errorf("internal server error while creating new account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, ti)
}
// accountCreate does the dirty work of making an account and user in the database.
// It then returns a token to the caller, for use with the new account, as per the
// spec here: https://docs.joinmastodon.org/methods/accounts/
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
l := m.log.WithField("func", "accountCreate")
// don't store a reason if we don't require one
reason := form.Reason
if !m.config.AccountsConfig.ReasonRequired {
reason = ""
}
l.Trace("creating new username and account")
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale, app.ID)
if err != nil {
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
}
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
accessToken, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
if err != nil {
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
}
return &mastotypes.Token{
AccessToken: accessToken.GetAccess(),
TokenType: "Bearer",
Scope: accessToken.GetScope(),
CreatedAt: accessToken.GetAccessCreateAt().Unix(),
}, nil
}
// validateCreateAccount checks through all the necessary prerequisites for creating a new account,
// according to the provided account create request. If the account isn't eligible, an error will be returned.
func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
if !c.OpenRegistration {
return errors.New("registration is not open for this server")
}
if err := util.ValidateUsername(form.Username); err != nil {
return err
}
if err := util.ValidateEmail(form.Email); err != nil {
return err
}
if err := util.ValidateNewPassword(form.Password); err != nil {
return err
}
if !form.Agreement {
return errors.New("agreement to terms and conditions not given")
}
if err := util.ValidateLanguage(form.Locale); err != nil {
return err
}
if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
return err
}
if err := database.IsEmailAvailable(form.Email); err != nil {
return err
}
if err := database.IsUsernameAvailable(form.Username); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,545 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/models"
oauthmodels "github.com/superseriousbusiness/oauth2/v4/models"
"golang.org/x/crypto/bcrypt"
)
type AccountCreateTestSuite struct {
suite.Suite
config *config.Config
log *logrus.Logger
testAccountLocal *model.Account
testApplication *model.Application
testToken oauth2.TokenInfo
mockOauthServer *oauth.MockServer
mockStorage *storage.MockStorage
mediaHandler media.MediaHandler
db db.DB
accountModule *accountModule
newUserFormHappyPath url.Values
}
/*
TEST INFRASTRUCTURE
*/
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *AccountCreateTestSuite) SetupSuite() {
// some of our subsequent entities need a log so create this here
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
suite.testAccountLocal = &model.Account{
ID: uuid.NewString(),
Username: "test_user",
}
// can use this test application throughout
suite.testApplication = &model.Application{
ID: "weeweeeeeeeeeeeeee",
Name: "a test application",
Website: "https://some-application-website.com",
RedirectURI: "http://localhost:8080",
ClientID: "a-known-client-id",
ClientSecret: "some-secret",
Scopes: "read",
VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa",
}
// can use this test token throughout
suite.testToken = &oauthmodels.Token{
ClientID: "a-known-client-id",
RedirectURI: "http://localhost:8080",
Scope: "read",
Code: "123456789",
CodeCreateAt: time.Now(),
CodeExpiresIn: time.Duration(10 * time.Minute),
}
// Direct config to local postgres instance
c := config.Empty()
c.Protocol = "http"
c.Host = "localhost"
c.DBConfig = &config.DBConfig{
Type: "postgres",
Address: "localhost",
Port: 5432,
User: "postgres",
Password: "postgres",
Database: "postgres",
ApplicationName: "gotosocial",
}
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
c.StorageConfig = &config.StorageConfig{
Backend: "local",
BasePath: "/tmp",
ServeProtocol: "http",
ServeHost: "localhost",
ServeBasePath: "/fileserver/media",
}
suite.config = c
// use an actual database for this, because it's just easier than mocking one out
database, err := db.New(context.Background(), c, log)
if err != nil {
suite.FailNow(err.Error())
}
suite.db = database
// we need to mock the oauth server because account creation needs it to create a new token
suite.mockOauthServer = &oauth.MockServer{}
suite.mockOauthServer.On("GenerateUserAccessToken", suite.testToken, suite.testApplication.ClientSecret, mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
l := suite.log.WithField("func", "GenerateUserAccessToken")
token := args.Get(0).(oauth2.TokenInfo)
l.Infof("received token %+v", token)
clientSecret := args.Get(1).(string)
l.Infof("received clientSecret %+v", clientSecret)
userID := args.Get(2).(string)
l.Infof("received userID %+v", userID)
}).Return(&models.Token{
Code: "we're authorized now!",
}, nil)
suite.mockStorage = &storage.MockStorage{}
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
// and finally here's the thing we're actually testing!
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
}
func (suite *AccountCreateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
}
// SetupTest creates a db connection and creates necessary tables before each test
func (suite *AccountCreateTestSuite) SetupTest() {
// create all the tables we might need in thie suite
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.CreateTable(m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
// form to submit for happy path account create requests -- this will be changed inside tests so it's better to set it before each test
suite.newUserFormHappyPath = url.Values{
"reason": []string{"a very good reason that's at least 40 characters i swear"},
"username": []string{"test_user"},
"email": []string{"user@example.org"},
"password": []string{"very-strong-password"},
"agreement": []string{"true"},
"locale": []string{"en"},
}
// same with accounts config
suite.config.AccountsConfig = &config.AccountsConfig{
OpenRegistration: true,
RequireApproval: true,
ReasonRequired: true,
}
}
// TearDownTest drops tables to make sure there's no data in the db
func (suite *AccountCreateTestSuite) TearDownTest() {
// remove all the tables we might have used so it's clear for the next test
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
}
/*
ACTUAL TESTS
*/
/*
TESTING: AccountCreatePOSTHandler
*/
// TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid,
// and at the end of it a new user and account should be added into the database.
//
// This is the handler served at /api/v1/accounts as POST
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerSuccessful() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
// 1. we should have OK from our call to the function
suite.EqualValues(http.StatusOK, recorder.Code)
// 2. we should have a token in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
t := &mastotypes.Token{}
err = json.Unmarshal(b, t)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), "we're authorized now!", t.AccessToken)
// check new account
// 1. we should be able to get the new account from the db
acct := &model.Account{}
err = suite.db.GetWhere("username", "test_user", acct)
assert.NoError(suite.T(), err)
assert.NotNil(suite.T(), acct)
// 2. reason should be set
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("reason"), acct.Reason)
// 3. display name should be equal to username by default
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("username"), acct.DisplayName)
// 4. domain should be nil because this is a local account
assert.Nil(suite.T(), nil, acct.Domain)
// 5. id should be set and parseable as a uuid
assert.NotNil(suite.T(), acct.ID)
_, err = uuid.Parse(acct.ID)
assert.Nil(suite.T(), err)
// 6. private and public key should be set
assert.NotNil(suite.T(), acct.PrivateKey)
assert.NotNil(suite.T(), acct.PublicKey)
// check new user
// 1. we should be able to get the new user from the db
usr := &model.User{}
err = suite.db.GetWhere("unconfirmed_email", suite.newUserFormHappyPath.Get("email"), usr)
assert.Nil(suite.T(), err)
assert.NotNil(suite.T(), usr)
// 2. user should have account id set to account we got above
assert.Equal(suite.T(), acct.ID, usr.AccountID)
// 3. id should be set and parseable as a uuid
assert.NotNil(suite.T(), usr.ID)
_, err = uuid.Parse(usr.ID)
assert.Nil(suite.T(), err)
// 4. locale should be equal to what we requested
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("locale"), usr.Locale)
// 5. created by application id should be equal to the app id
assert.Equal(suite.T(), suite.testApplication.ID, usr.CreatedByApplicationID)
// 6. password should be matcheable to what we set above
err = bcrypt.CompareHashAndPassword([]byte(usr.EncryptedPassword), []byte(suite.newUserFormHappyPath.Get("password")))
assert.Nil(suite.T(), err)
}
// TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no authorization is provided:
// only registered applications can create accounts, and we don't provide one here.
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoAuth() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
// 1. we should have forbidden from our call to the function because we didn't auth
suite.EqualValues(http.StatusForbidden, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
}
// TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no form is provided at all.
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoForm() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"missing one or more required form values"}`, string(b))
}
// TestAccountCreatePOSTHandlerWeakPassword makes sure that the handler fails when a weak password is provided
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeakPassword() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
// set a weak password
ctx.Request.Form.Set("password", "weak")
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"insecure password, try including more special characters, using uppercase letters, using numbers or using a longer password"}`, string(b))
}
// TestAccountCreatePOSTHandlerWeirdLocale makes sure that the handler fails when a weird locale is provided
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeirdLocale() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
// set an invalid locale
ctx.Request.Form.Set("locale", "neverneverland")
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"language: tag is not well-formed"}`, string(b))
}
// TestAccountCreatePOSTHandlerRegistrationsClosed makes sure that the handler fails when registrations are closed
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerRegistrationsClosed() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
// close registrations
suite.config.AccountsConfig.OpenRegistration = false
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"registration is not open for this server"}`, string(b))
}
// TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when no reason is provided but one is required
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerReasonNotProvided() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
// remove reason
ctx.Request.Form.Set("reason", "")
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"no reason provided"}`, string(b))
}
// TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when a crappy reason is presented but a good one is required
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() {
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = suite.newUserFormHappyPath
// remove reason
ctx.Request.Form.Set("reason", "just cuz")
suite.accountModule.accountCreatePOSTHandler(ctx)
// check response
suite.EqualValues(http.StatusBadRequest, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b))
}
/*
TESTING: AccountUpdateCredentialsPATCHHandler
*/
func (suite *AccountCreateTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
// put test local account in db
err := suite.db.Put(suite.testAccountLocal)
assert.NoError(suite.T(), err)
// attach avatar to request
aviFile, err := os.Open("../../media/test/test-jpeg.jpg")
assert.NoError(suite.T(), err)
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
assert.NoError(suite.T(), err)
_, err = io.Copy(part, aviFile)
assert.NoError(suite.T(), err)
err = aviFile.Close()
assert.NoError(suite.T(), err)
err = writer.Close()
assert.NoError(suite.T(), err)
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
// check response
// 1. we should have OK because our request was valid
suite.EqualValues(http.StatusOK, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
// TODO: implement proper checks here
//
// b, err := ioutil.ReadAll(result.Body)
// assert.NoError(suite.T(), err)
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
}
func TestAccountCreateTestSuite(t *testing.T) {
suite.Run(t, new(AccountCreateTestSuite))
}

View file

@ -0,0 +1,57 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
)
// accountGetHandler serves the account information held by the server in response to a GET
// request. It should be served as a GET at /api/v1/accounts/:id.
//
// See: https://docs.joinmastodon.org/methods/accounts/
func (m *accountModule) accountGETHandler(c *gin.Context) {
targetAcctID := c.Param(idKey)
if targetAcctID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "no account id specified"})
return
}
targetAccount := &model.Account{}
if err := m.db.GetByID(targetAcctID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
c.JSON(http.StatusNotFound, gin.H{"error": "Record not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acctInfo, err := m.db.AccountToMastoPublic(targetAccount)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, acctInfo)
}

View file

@ -0,0 +1,259 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
)
// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings.
// It should be served as a PATCH at /api/v1/accounts/update_credentials
//
// TODO: this can be optimized massively by building up a picture of what we want the new account
// details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one
// which is not gonna make the database very happy when lots of requests are going through.
// This way it would also be safer because the update won't happen until *all* the fields are validated.
// Otherwise we risk doing a partial update and that's gonna cause probllleeemmmsss.
func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler")
authed, err := oauth.MustAuth(c, true, false, false, true)
if err != nil {
l.Debugf("couldn't auth: %s", err)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
l.Tracef("retrieved account %+v", authed.Account.ID)
l.Trace("parsing request form")
form := &mastotypes.UpdateCredentialsRequest{}
if err := c.ShouldBind(form); err != nil || form == nil {
l.Debugf("could not parse form from request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// if everything on the form is nil, then nothing has been set and we shouldn't continue
if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil {
l.Debugf("could not parse form from request")
c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"})
return
}
if form.Discoverable != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil {
l.Debugf("error updating discoverable: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Bot != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil {
l.Debugf("error updating bot: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.DisplayName != nil {
if err := util.ValidateDisplayName(*form.DisplayName); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Note != nil {
if err := util.ValidateNote(*form.Note); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil {
l.Debugf("error updating note: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Avatar != nil && form.Avatar.Size != 0 {
avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID)
if err != nil {
l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo)
}
if form.Header != nil && form.Header.Size != 0 {
headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID)
if err != nil {
l.Debugf("could not update header for account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo)
}
if form.Locked != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Source != nil {
if form.Source.Language != nil {
if err := util.ValidateLanguage(*form.Source.Language); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := m.db.UpdateOneByID(authed.Account.ID, "language", *form.Source.Language, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Source.Sensitive != nil {
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if form.Source.Privacy != nil {
if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := m.db.UpdateOneByID(authed.Account.ID, "privacy", *form.Source.Privacy, &model.Account{}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
}
// if form.FieldsAttributes != nil {
// // TODO: parse fields attributes nicely and update
// }
// fetch the account with all updated values set
updatedAccount := &model.Account{}
if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil {
l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount)
if err != nil {
l.Tracef("could not convert account into mastosensitive account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
c.JSON(http.StatusOK, acctSensitive)
}
/*
HELPER FUNCTIONS
*/
// TODO: try to combine the below two functions because this is a lot of code repetition.
// UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
var err error
if int(avatar.Size) > m.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize)
return nil, err
}
f, err := avatar.Open()
if err != nil {
return nil, fmt.Errorf("could not read provided avatar: %s", err)
}
// extract the bytes
buf := new(bytes.Buffer)
size, err := io.Copy(buf, f)
if err != nil {
return nil, fmt.Errorf("could not read provided avatar: %s", err)
}
if size == 0 {
return nil, errors.New("could not read provided avatar: size 0 bytes")
}
// do the setting
avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar")
if err != nil {
return nil, fmt.Errorf("error processing avatar: %s", err)
}
return avatarInfo, f.Close()
}
// UpdateAccountHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
var err error
if int(header.Size) > m.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize)
return nil, err
}
f, err := header.Open()
if err != nil {
return nil, fmt.Errorf("could not read provided header: %s", err)
}
// extract the bytes
buf := new(bytes.Buffer)
size, err := io.Copy(buf, f)
if err != nil {
return nil, fmt.Errorf("could not read provided header: %s", err)
}
if size == 0 {
return nil, errors.New("could not read provided header: size 0 bytes")
}
// do the setting
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header")
if err != nil {
return nil, fmt.Errorf("error processing header: %s", err)
}
return headerInfo, f.Close()
}

View file

@ -0,0 +1,298 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/models"
oauthmodels "github.com/superseriousbusiness/oauth2/v4/models"
)
type AccountUpdateTestSuite struct {
suite.Suite
config *config.Config
log *logrus.Logger
testAccountLocal *model.Account
testApplication *model.Application
testToken oauth2.TokenInfo
mockOauthServer *oauth.MockServer
mockStorage *storage.MockStorage
mediaHandler media.MediaHandler
db db.DB
accountModule *accountModule
newUserFormHappyPath url.Values
}
/*
TEST INFRASTRUCTURE
*/
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *AccountUpdateTestSuite) SetupSuite() {
// some of our subsequent entities need a log so create this here
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
suite.testAccountLocal = &model.Account{
ID: uuid.NewString(),
Username: "test_user",
}
// can use this test application throughout
suite.testApplication = &model.Application{
ID: "weeweeeeeeeeeeeeee",
Name: "a test application",
Website: "https://some-application-website.com",
RedirectURI: "http://localhost:8080",
ClientID: "a-known-client-id",
ClientSecret: "some-secret",
Scopes: "read",
VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa",
}
// can use this test token throughout
suite.testToken = &oauthmodels.Token{
ClientID: "a-known-client-id",
RedirectURI: "http://localhost:8080",
Scope: "read",
Code: "123456789",
CodeCreateAt: time.Now(),
CodeExpiresIn: time.Duration(10 * time.Minute),
}
// Direct config to local postgres instance
c := config.Empty()
c.Protocol = "http"
c.Host = "localhost"
c.DBConfig = &config.DBConfig{
Type: "postgres",
Address: "localhost",
Port: 5432,
User: "postgres",
Password: "postgres",
Database: "postgres",
ApplicationName: "gotosocial",
}
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
c.StorageConfig = &config.StorageConfig{
Backend: "local",
BasePath: "/tmp",
ServeProtocol: "http",
ServeHost: "localhost",
ServeBasePath: "/fileserver/media",
}
suite.config = c
// use an actual database for this, because it's just easier than mocking one out
database, err := db.New(context.Background(), c, log)
if err != nil {
suite.FailNow(err.Error())
}
suite.db = database
// we need to mock the oauth server because account creation needs it to create a new token
suite.mockOauthServer = &oauth.MockServer{}
suite.mockOauthServer.On("GenerateUserAccessToken", suite.testToken, suite.testApplication.ClientSecret, mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
l := suite.log.WithField("func", "GenerateUserAccessToken")
token := args.Get(0).(oauth2.TokenInfo)
l.Infof("received token %+v", token)
clientSecret := args.Get(1).(string)
l.Infof("received clientSecret %+v", clientSecret)
userID := args.Get(2).(string)
l.Infof("received userID %+v", userID)
}).Return(&models.Token{
Code: "we're authorized now!",
}, nil)
suite.mockStorage = &storage.MockStorage{}
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
// and finally here's the thing we're actually testing!
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
}
func (suite *AccountUpdateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
}
// SetupTest creates a db connection and creates necessary tables before each test
func (suite *AccountUpdateTestSuite) SetupTest() {
// create all the tables we might need in thie suite
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.CreateTable(m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
// form to submit for happy path account create requests -- this will be changed inside tests so it's better to set it before each test
suite.newUserFormHappyPath = url.Values{
"reason": []string{"a very good reason that's at least 40 characters i swear"},
"username": []string{"test_user"},
"email": []string{"user@example.org"},
"password": []string{"very-strong-password"},
"agreement": []string{"true"},
"locale": []string{"en"},
}
// same with accounts config
suite.config.AccountsConfig = &config.AccountsConfig{
OpenRegistration: true,
RequireApproval: true,
ReasonRequired: true,
}
}
// TearDownTest drops tables to make sure there's no data in the db
func (suite *AccountUpdateTestSuite) TearDownTest() {
// remove all the tables we might have used so it's clear for the next test
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
}
/*
ACTUAL TESTS
*/
/*
TESTING: AccountUpdateCredentialsPATCHHandler
*/
func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
// put test local account in db
err := suite.db.Put(suite.testAccountLocal)
assert.NoError(suite.T(), err)
// attach avatar to request form
avatarFile, err := os.Open("../../media/test/test-jpeg.jpg")
assert.NoError(suite.T(), err)
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
avatarPart, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
assert.NoError(suite.T(), err)
_, err = io.Copy(avatarPart, avatarFile)
assert.NoError(suite.T(), err)
err = avatarFile.Close()
assert.NoError(suite.T(), err)
// set display name to a new value
displayNamePart, err := writer.CreateFormField("display_name")
assert.NoError(suite.T(), err)
_, err = io.Copy(displayNamePart, bytes.NewBufferString("test_user_wohoah"))
assert.NoError(suite.T(), err)
// set locked to true
lockedPart, err := writer.CreateFormField("locked")
assert.NoError(suite.T(), err)
_, err = io.Copy(lockedPart, bytes.NewBufferString("true"))
assert.NoError(suite.T(), err)
// close the request writer, the form is now prepared
err = writer.Close()
assert.NoError(suite.T(), err)
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
// check response
// 1. we should have OK because our request was valid
suite.EqualValues(http.StatusOK, recorder.Code)
// 2. we should have an error message in the result body
result := recorder.Result()
defer result.Body.Close()
// TODO: implement proper checks here
//
// b, err := ioutil.ReadAll(result.Body)
// assert.NoError(suite.T(), err)
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
}
func TestAccountUpdateTestSuite(t *testing.T) {
suite.Run(t, new(AccountUpdateTestSuite))
}

View file

@ -0,0 +1,50 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package account
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// accountVerifyGETHandler serves a user's account details to them IF they reached this
// handler while in possession of a valid token, according to the oauth middleware.
// It should be served as a GET at /api/v1/accounts/verify_credentials
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
l := m.log.WithField("func", "accountVerifyGETHandler")
authed, err := oauth.MustAuth(c, true, false, false, true)
if err != nil {
l.Debugf("couldn't auth: %s", err)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
l.Tracef("retrieved account %+v, converting to mastosensitive...", authed.Account.ID)
acctSensitive, err := m.db.AccountToMastoSensitive(authed.Account)
if err != nil {
l.Tracef("could not convert account into mastosensitive account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
c.JSON(http.StatusOK, acctSensitive)
}

View file

@ -17,21 +17,3 @@
*/
package account
import (
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/router"
)
type accountModule struct {
}
// New returns a new account module
func New() module.ClientAPIModule {
return &accountModule{}
}
// Route attaches all routes from this module to the given router
func (m *accountModule) Route(r router.Router) error {
return nil
}

View file

@ -16,14 +16,18 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
package module
// Package apimodule is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
package apimodule
import "github.com/gotosocial/gotosocial/internal/router"
import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set
// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;)
// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/
type ClientAPIModule interface {
Route(s router.Router) error
CreateTables(db db.DB) error
}

View file

@ -0,0 +1,71 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
import (
"fmt"
"net/http"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
const appsPath = "/api/v1/apps"
type appModule struct {
server oauth.Server
db db.DB
log *logrus.Logger
}
// New returns a new auth module
func New(srv oauth.Server, db db.DB, log *logrus.Logger) apimodule.ClientAPIModule {
return &appModule{
server: srv,
db: db,
log: log,
}
}
// Route satisfies the RESTAPIModule interface
func (m *appModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
return nil
}
func (m *appModule) CreateTables(db db.DB) error {
models := []interface{}{
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
// TODO: write tests

View file

@ -0,0 +1,113 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
)
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *appModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauth.Client{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMasto())
}

View file

@ -1,4 +1,4 @@
# oauth
# auth
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.

View file

@ -0,0 +1,89 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Package auth is a module that provides oauth functionality to a router.
// It adds the following paths:
// /auth/sign_in
// /oauth/token
// /oauth/authorize
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
package auth
import (
"fmt"
"net/http"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
const (
authSignInPath = "/auth/sign_in"
oauthTokenPath = "/oauth/token"
oauthAuthorizePath = "/oauth/authorize"
)
type authModule struct {
server oauth.Server
db db.DB
log *logrus.Logger
}
// New returns a new auth module
func New(srv oauth.Server, db db.DB, log *logrus.Logger) apimodule.ClientAPIModule {
return &authModule{
server: srv,
db: db,
log: log,
}
}
// Route satisfies the RESTAPIModule interface
func (m *authModule) Route(s router.Router) error {
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
s.AttachMiddleware(m.oauthTokenMiddleware)
return nil
}
func (m *authModule) CreateTables(db db.DB) error {
models := []interface{}{
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
package auth
import (
"context"
@ -25,30 +25,29 @@ import (
"time"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/gtsmodel"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/router"
"golang.org/x/crypto/bcrypt"
)
type OauthTestSuite struct {
type AuthTestSuite struct {
suite.Suite
tokenStore oauth2.TokenStore
clientStore oauth2.ClientStore
oauthServer oauth.Server
db db.DB
testAccount *gtsmodel.Account
testApplication *gtsmodel.Application
testUser *gtsmodel.User
testClient *oauthClient
testAccount *model.Account
testApplication *model.Application
testUser *model.User
testClient *oauth.Client
config *config.Config
}
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *OauthTestSuite) SetupSuite() {
func (suite *AuthTestSuite) SetupSuite() {
c := config.Empty()
// we're running on localhost without https so set the protocol to http
c.Protocol = "http"
@ -76,21 +75,21 @@ func (suite *OauthTestSuite) SetupSuite() {
acctID := uuid.NewString()
suite.testAccount = &gtsmodel.Account{
suite.testAccount = &model.Account{
ID: acctID,
Username: "test_user",
}
suite.testUser = &gtsmodel.User{
suite.testUser = &model.User{
EncryptedPassword: string(encryptedPassword),
Email: "user@example.org",
AccountID: acctID,
}
suite.testClient = &oauthClient{
suite.testClient = &oauth.Client{
ID: "a-known-client-id",
Secret: "some-secret",
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
}
suite.testApplication = &gtsmodel.Application{
suite.testApplication = &model.Application{
Name: "a test application",
Website: "https://some-application-website.com",
RedirectURI: "http://localhost:8080",
@ -102,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
}
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *OauthTestSuite) SetupTest() {
func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
@ -114,11 +113,11 @@ func (suite *OauthTestSuite) SetupTest() {
suite.db = db
models := []interface{}{
&oauthClient{},
&oauthToken{},
&gtsmodel.User{},
&gtsmodel.Account{},
&gtsmodel.Application{},
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
}
for _, m := range models {
@ -127,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
}
}
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
suite.clientStore = newClientStore(suite.db)
suite.oauthServer = oauth.New(suite.db, log)
if err := suite.db.Put(suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
@ -146,13 +144,13 @@ func (suite *OauthTestSuite) SetupTest() {
}
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *OauthTestSuite) TearDownTest() {
func (suite *AuthTestSuite) TearDownTest() {
models := []interface{}{
&oauthClient{},
&oauthToken{},
&gtsmodel.User{},
&gtsmodel.Account{},
&gtsmodel.Application{},
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
@ -165,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
suite.db = nil
}
func (suite *OauthTestSuite) TestAPIInitialize() {
func (suite *AuthTestSuite) TestAPIInitialize() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
@ -174,18 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
api := New(suite.oauthServer, suite.db, log)
if err := api.Route(r); err != nil {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
go r.Start()
r.Start()
time.Sleep(60 * time.Second)
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
// curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
// curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
if err := r.Stop(context.Background()); err != nil {
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
}
}
func TestOauthTestSuite(t *testing.T) {
suite.Run(t, new(OauthTestSuite))
func TestAuthTestSuite(t *testing.T) {
suite.Run(t, new(AuthTestSuite))
}

View file

@ -0,0 +1,204 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
)
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
// The idea here is to present an oauth authorize page to the user, with a button
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *authModule) authorizeGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizeGETHandler")
s := sessions.Default(c)
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
userID, ok := s.Get("userid").(string)
if !ok || userID == "" {
l.Trace("userid was empty, parsing form then redirecting to sign in page")
if err := parseAuthForm(c, l); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
} else {
c.Redirect(http.StatusFound, authSignInPath)
}
return
}
// We can use the client_id on the session to retrieve info about the app associated with the client_id
clientID, ok := s.Get("client_id").(string)
if !ok || clientID == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
app := &model.Application{
ClientID: clientID,
}
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
user := &model.User{
ID: userID,
}
if err := m.db.GetByID(user.ID, user); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acct := &model.Account{
ID: user.AccountID,
}
if err := m.db.GetByID(acct.ID, acct); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Finally we should also get the redirect and scope of this particular request, as stored in the session.
redirect, ok := s.Get("redirect_uri").(string)
if !ok || redirect == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
return
}
scope, ok := s.Get("scope").(string)
if !ok || scope == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
return
}
// the authorize template will display a form to the user where they can get some information
// about the app that's trying to authorize, and the scope of the request.
// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
l.Trace("serving authorize html")
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
"appname": app.Name,
"appwebsite": app.Website,
"redirect": redirect,
"scope": scope,
"user": acct.Username,
})
}
// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *authModule) authorizePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizePOSTHandler")
s := sessions.Default(c)
// At this point we know the user has said 'yes' to allowing the application and oauth client
// work for them, so we can set the
// We need to retrieve the original form submitted to the authorizeGEThandler, and
// recreate it on the request so that it can be used further by the oauth2 library.
// So first fetch all the values from the session.
forceLogin, ok := s.Get("force_login").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
return
}
responseType, ok := s.Get("response_type").(string)
if !ok || responseType == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
return
}
clientID, ok := s.Get("client_id").(string)
if !ok || clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
return
}
redirectURI, ok := s.Get("redirect_uri").(string)
if !ok || redirectURI == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
return
}
scope, ok := s.Get("scope").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
return
}
userID, ok := s.Get("userid").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
return
}
// we're done with the session so we can clear it now
s.Clear()
// now set the values on the request
values := url.Values{}
values.Set("force_login", forceLogin)
values.Set("response_type", responseType)
values.Set("client_id", clientID)
values.Set("redirect_uri", redirectURI)
values.Set("scope", scope)
values.Set("userid", userID)
c.Request.Form = values
l.Tracef("values on request set to %+v", c.Request.Form)
// and proceed with authorization using the oauth2 library
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
}
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
// the values in the form into the session.
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
s := sessions.Default(c)
// first make sure they've filled out the authorize form with the required values
form := &mastotypes.OAuthAuthorize{}
if err := c.ShouldBind(form); err != nil {
return err
}
l.Tracef("parsed form: %+v", form)
// these fields are *required* so check 'em
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
return errors.New("missing one of: response_type, client_id or redirect_uri")
}
// set default scope to read
if form.Scope == "" {
form.Scope = "read"
}
// save these values from the form so we can use them elsewhere in the session
s.Set("force_login", form.ForceLogin)
s.Set("response_type", form.ResponseType)
s.Set("client_id", form.ClientID)
s.Set("redirect_uri", form.RedirectURI)
s.Set("scope", form.Scope)
return s.Save()
}

View file

@ -0,0 +1,76 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// oauthTokenMiddleware checks if the client has presented a valid oauth Bearer token.
// If so, it will check the User that the token belongs to, and set that in the context of
// the request. Then, it will look up the account for that user, and set that in the request too.
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
l := m.log.WithField("func", "ValidatePassword")
l.Trace("entering OauthTokenMiddleware")
ti, err := m.server.ValidationBearerToken(c.Request)
if err != nil {
l.Trace("no valid token presented: continuing with unauthenticated request")
return
}
c.Set(oauth.SessionAuthorizedToken, ti)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
// check for user-level token
if uid := ti.GetUserID(); uid != "" {
l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
// fetch user's and account for this user id
user := &model.User{}
if err := m.db.GetByID(uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return
}
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
acct := &model.Account{}
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
c.Set(oauth.SessionAuthorizedAccount, acct)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
}
// check for application token
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &model.Application{}
if err := m.db.GetWhere("client_id", cid, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
}
}

View file

@ -0,0 +1,115 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"errors"
"net/http"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"golang.org/x/crypto/bcrypt"
)
type login struct {
Email string `form:"username"`
Password string `form:"password"`
}
// signInGETHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
func (m *authModule) signInGETHandler(c *gin.Context) {
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
}
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The handler will then redirect to the auth handler served at /auth
func (m *authModule) signInPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "SignInPOSTHandler")
s := sessions.Default(c)
form := &login{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("parsed form: %+v", form)
userid, err := m.validatePassword(form.Email, form.Password)
if err != nil {
c.String(http.StatusForbidden, err.Error())
return
}
s.Set("userid", userid)
if err := s.Save(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
l.Trace("redirecting to auth page")
c.Redirect(http.StatusFound, oauthAuthorizePath)
}
// validatePassword takes an email address and a password.
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a uuid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
if email == "" || password == "" {
l.Debug("email or password was not provided")
return incorrectPassword()
}
// first we select the user from the database based on email address, bail if no user found for that email
gtsUser := &model.User{}
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword()
}
// make sure a password is actually set and bail if not
if gtsUser.EncryptedPassword == "" {
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
return incorrectPassword()
}
// compare the provided password with the encrypted one from the db, bail if they don't match
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
return incorrectPassword()
}
// If we've made it this far the email/password is correct, so we can just return the id of the user.
userid = gtsUser.ID
l.Tracef("returning (%s, %s)", userid, err)
return
}
// incorrectPassword is just a little helper function to use in the ValidatePassword function
func incorrectPassword() (string, error) {
return "", errors.New("password/email combination was incorrect")
}

View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"net/http"
"github.com/gin-gonic/gin"
)
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
func (m *authModule) tokenPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "TokenPOSTHandler")
l.Trace("entered TokenPOSTHandler")
if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}

View file

@ -0,0 +1,63 @@
package fileserver
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/router"
"github.com/superseriousbusiness/gotosocial/internal/storage"
)
// fileServer implements the RESTAPIModule interface.
// The goal here is to serve requested media files if the gotosocial server is configured to use local storage.
type fileServer struct {
config *config.Config
db db.DB
storage storage.Storage
log *logrus.Logger
storageBase string
}
// New returns a new fileServer module
func New(config *config.Config, db db.DB, storage storage.Storage, log *logrus.Logger) apimodule.ClientAPIModule {
storageBase := config.StorageConfig.BasePath // TODO: do this properly
return &fileServer{
config: config,
db: db,
storage: storage,
log: log,
storageBase: storageBase,
}
}
// Route satisfies the RESTAPIModule interface
func (m *fileServer) Route(s router.Router) error {
// s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
return nil
}
func (m *fileServer) CreateTables(db db.DB) error {
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.FollowRequest{},
&model.Status{},
&model.Application{},
&model.EmailDomainBlock{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}

View file

@ -0,0 +1,27 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package apimodule
import (
mock "github.com/stretchr/testify/mock"
router "github.com/superseriousbusiness/gotosocial/internal/router"
)
// MockClientAPIModule is an autogenerated mock type for the ClientAPIModule type
type MockClientAPIModule struct {
mock.Mock
}
// Route provides a mock function with given fields: s
func (_m *MockClientAPIModule) Route(s router.Router) error {
ret := _m.Called(s)
var r0 error
if rf, ok := ret.Get(0).(func(router.Router) error); ok {
r0 = rf(s)
} else {
r0 = ret.Error(0)
}
return r0
}

47
internal/cache/mock_Cache.go vendored Normal file
View file

@ -0,0 +1,47 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package cache
import mock "github.com/stretchr/testify/mock"
// MockCache is an autogenerated mock type for the Cache type
type MockCache struct {
mock.Mock
}
// Fetch provides a mock function with given fields: k
func (_m *MockCache) Fetch(k string) (interface{}, error) {
ret := _m.Called(k)
var r0 interface{}
if rf, ok := ret.Get(0).(func(string) interface{}); ok {
r0 = rf(k)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(interface{})
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(k)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Store provides a mock function with given fields: k, v
func (_m *MockCache) Store(k string, v interface{}) error {
ret := _m.Called(k, v)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
r0 = rf(k, v)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -0,0 +1,29 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package config
// AccountsConfig contains configuration to do with creating accounts, new registrations, and defaults.
type AccountsConfig struct {
// Do we want people to be able to just submit sign up requests, or do we want invite only?
OpenRegistration bool `yaml:"openRegistration"`
// Do sign up requests require approval from an admin/moderator?
RequireApproval bool `yaml:"requireApproval"`
// Do we require a reason for a sign up or is an empty string OK?
ReasonRequired bool `yaml:"reasonRequired"`
}

View file

@ -33,26 +33,21 @@ type Config struct {
Protocol string `yaml:"protocol"`
DBConfig *DBConfig `yaml:"db"`
TemplateConfig *TemplateConfig `yaml:"template"`
AccountsConfig *AccountsConfig `yaml:"accounts"`
MediaConfig *MediaConfig `yaml:"media"`
StorageConfig *StorageConfig `yaml:"storage"`
}
// FromFile returns a new config from a file, or an error if something goes amiss.
func FromFile(path string) (*Config, error) {
if path != "" {
c, err := loadFromFile(path)
if err != nil {
return nil, fmt.Errorf("error creating config: %s", err)
}
return c, nil
}
// Default returns a new config with default values.
// Not yet implemented.
func Default() *Config {
// TODO: find a way of doing this without code repetition, because having to
// repeat all values here and elsewhere is annoying and gonna be prone to mistakes.
return &Config{
DBConfig: &DBConfig{},
TemplateConfig: &TemplateConfig{},
}
return Empty(), nil
}
// Empty just returns an empty config
@ -60,6 +55,9 @@ func Empty() *Config {
return &Config{
DBConfig: &DBConfig{},
TemplateConfig: &TemplateConfig{},
AccountsConfig: &AccountsConfig{},
MediaConfig: &MediaConfig{},
StorageConfig: &StorageConfig{},
}
}
@ -136,11 +134,51 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) {
if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) {
c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir)
}
// accounts flags
if f.IsSet(fn.AccountsOpenRegistration) {
c.AccountsConfig.OpenRegistration = f.Bool(fn.AccountsOpenRegistration)
}
if f.IsSet(fn.AccountsRequireApproval) {
c.AccountsConfig.RequireApproval = f.Bool(fn.AccountsRequireApproval)
}
// media flags
if c.MediaConfig.MaxImageSize == 0 || f.IsSet(fn.MediaMaxImageSize) {
c.MediaConfig.MaxImageSize = f.Int(fn.MediaMaxImageSize)
}
if c.MediaConfig.MaxVideoSize == 0 || f.IsSet(fn.MediaMaxVideoSize) {
c.MediaConfig.MaxVideoSize = f.Int(fn.MediaMaxVideoSize)
}
// storage flags
if c.StorageConfig.Backend == "" || f.IsSet(fn.StorageBackend) {
c.StorageConfig.Backend = f.String(fn.StorageBackend)
}
if c.StorageConfig.BasePath == "" || f.IsSet(fn.StorageBasePath) {
c.StorageConfig.BasePath = f.String(fn.StorageBasePath)
}
if c.StorageConfig.ServeProtocol == "" || f.IsSet(fn.StorageServeProtocol) {
c.StorageConfig.ServeProtocol = f.String(fn.StorageServeProtocol)
}
if c.StorageConfig.ServeHost == "" || f.IsSet(fn.StorageServeHost) {
c.StorageConfig.ServeHost = f.String(fn.StorageServeHost)
}
if c.StorageConfig.ServeBasePath == "" || f.IsSet(fn.StorageServeBasePath) {
c.StorageConfig.ServeBasePath = f.String(fn.StorageServeBasePath)
}
}
// KeyedFlags is a wrapper for any type that can store keyed flags and give them back.
// HINT: This works with a urfave cli context struct ;)
type KeyedFlags interface {
Bool(k string) bool
String(k string) string
Int(k string) int
IsSet(k string) bool
@ -154,13 +192,27 @@ type Flags struct {
ConfigPath string
Host string
Protocol string
DbType string
DbAddress string
DbPort string
DbUser string
DbPassword string
DbDatabase string
TemplateBaseDir string
AccountsOpenRegistration string
AccountsRequireApproval string
MediaMaxImageSize string
MediaMaxVideoSize string
StorageBackend string
StorageBasePath string
StorageServeProtocol string
StorageServeHost string
StorageServeBasePath string
}
// GetFlagNames returns a struct containing the names of the various flags used for
@ -172,13 +224,27 @@ func GetFlagNames() Flags {
ConfigPath: "config-path",
Host: "host",
Protocol: "protocol",
DbType: "db-type",
DbAddress: "db-address",
DbPort: "db-port",
DbUser: "db-user",
DbPassword: "db-password",
DbDatabase: "db-database",
TemplateBaseDir: "template-basedir",
AccountsOpenRegistration: "accounts-open-registration",
AccountsRequireApproval: "accounts-require-approval",
MediaMaxImageSize: "media-max-image-size",
MediaMaxVideoSize: "media-max-video-size",
StorageBackend: "storage-backend",
StorageBasePath: "storage-base-path",
StorageServeProtocol: "storage-serve-protocol",
StorageServeHost: "storage-serve-host",
StorageServeBasePath: "storage-serve-base-path",
}
}
@ -191,12 +257,26 @@ func GetEnvNames() Flags {
ConfigPath: "GTS_CONFIG_PATH",
Host: "GTS_HOST",
Protocol: "GTS_PROTOCOL",
DbType: "GTS_DB_TYPE",
DbAddress: "GTS_DB_ADDRESS",
DbPort: "GTS_DB_PORT",
DbUser: "GTS_DB_USER",
DbPassword: "GTS_DB_PASSWORD",
DbDatabase: "GTS_DB_DATABASE",
TemplateBaseDir: "GTS_TEMPLATE_BASEDIR",
AccountsOpenRegistration: "GTS_ACCOUNTS_OPEN_REGISTRATION",
AccountsRequireApproval: "GTS_ACCOUNTS_REQUIRE_APPROVAL",
MediaMaxImageSize: "GTS_MEDIA_MAX_IMAGE_SIZE",
MediaMaxVideoSize: "GTS_MEDIA_MAX_VIDEO_SIZE",
StorageBackend: "GTS_STORAGE_BACKEND",
StorageBasePath: "GTS_STORAGE_BASE_PATH",
StorageServeProtocol: "GTS_STORAGE_SERVE_PROTOCOL",
StorageServeHost: "GTS_STORAGE_SERVE_HOST",
StorageServeBasePath: "GTS_STORAGE_SERVE_BASE_PATH",
}
}

27
internal/config/media.go Normal file
View file

@ -0,0 +1,27 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package config
// MediaConfig contains configuration for receiving and parsing media files and attachments
type MediaConfig struct {
// Max size of uploaded images in bytes
MaxImageSize int `yaml:"maxImageSize"`
// Max size of uploaded video in bytes
MaxVideoSize int `yaml:"maxVideoSize"`
}

View file

@ -0,0 +1,66 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package config
import mock "github.com/stretchr/testify/mock"
// MockKeyedFlags is an autogenerated mock type for the KeyedFlags type
type MockKeyedFlags struct {
mock.Mock
}
// Bool provides a mock function with given fields: k
func (_m *MockKeyedFlags) Bool(k string) bool {
ret := _m.Called(k)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(k)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Int provides a mock function with given fields: k
func (_m *MockKeyedFlags) Int(k string) int {
ret := _m.Called(k)
var r0 int
if rf, ok := ret.Get(0).(func(string) int); ok {
r0 = rf(k)
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// IsSet provides a mock function with given fields: k
func (_m *MockKeyedFlags) IsSet(k string) bool {
ret := _m.Called(k)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(k)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// String provides a mock function with given fields: k
func (_m *MockKeyedFlags) String(k string) string {
ret := _m.Called(k)
var r0 string
if rf, ok := ret.Get(0).(func(string) string); ok {
r0 = rf(k)
} else {
r0 = ret.Get(0).(string)
}
return r0
}

View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package config
// StorageConfig contains configuration for storage and serving of media files and attachments
type StorageConfig struct {
// Type of storage backend to use: currently only 'local' is supported.
// TODO: add S3 support here.
Backend string `yaml:"backend"`
// The base path for storing things. Should be an already-existing directory.
BasePath string `yaml:"basePath"`
// Protocol to use when *serving* media files from storage
ServeProtocol string `yaml:"serveProtocol"`
// Host to use when *serving* media files from storage
ServeHost string `yaml:"serveHost"`
// Base path to use when *serving* media files from storage
ServeBasePath string `yaml:"serveBasePath"`
}

View file

@ -21,9 +21,9 @@ package db
import (
"context"
"github.com/gotosocial/gotosocial/internal/action"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/action"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
// Initialize will initialize the database given in the config for use with GoToSocial

View file

@ -21,53 +21,167 @@ package db
import (
"context"
"fmt"
"net"
"strings"
"github.com/go-fed/activity/pub"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
)
const dbTypePostgres string = "POSTGRES"
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
type ErrNoEntries struct{}
func (e ErrNoEntries) Error() string {
return "no entries"
}
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
// by whatever is returned from the database.
type DB interface {
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
Federation() pub.Database
// CreateTable creates a table for the given interface
/*
BASIC DB FUNCTIONALITY
*/
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) error
// DropTable drops the table for the given interface
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) error
// GetByID gets one entry by its id.
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) error
// GetWhere gets one entry where key = value
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(key string, value interface{}, i interface{}) error
// GetAll gets all entries of interface type i
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) error
// Put stores i
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) error
// Update by id updates i with id id
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
// Delete by id removes i with id id
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
// Delete where deletes i where key = value
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(key string, value interface{}, i interface{}) error
/*
HANDY SHORTCUTS
*/
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountByUserID(userID string, account *model.Account) error
// GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowingByAccountID(accountID string, following *[]model.Follow) error
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowersByAccountID(accountID string, followers *[]model.Follow) error
// GetStatusesByAccountID is a shortcut for the common action of fetching a list of statuses produced by accountID.
// The given slice 'statuses' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetStatusesByAccountID(accountID string, statuses *[]model.Status) error
// GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error
// GetLastStatusForAccountID simply gets the most recent status by the given account.
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLastStatusForAccountID(accountID string, status *model.Status) error
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) error
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error
// GetHeaderForAccountID gets the current header for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error
/*
USEFUL CONVERSION FUNCTIONS
*/
// AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error
// if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields,
// so serve it only to an authorized user who should have permission to see it.
AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error)
// AccountToMastoPublic takes a db model account as a param, and returns a populated mastotype account, or an error
// if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields.
// In other words, this is the public record that the server has of an account.
AccountToMastoPublic(account *model.Account) (*mastotypes.Account, error)
}
// New returns a new database service that satisfies the DB interface and, by extension,

View file

@ -0,0 +1,159 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"errors"
"net/url"
"sync"
"github.com/go-fed/activity/pub"
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
// FederatingDB uses the underlying DB interface to implement the go-fed pub.Database interface.
// It doesn't care what the underlying implementation of the DB interface is, as long as it works.
type federatingDB struct {
locks *sync.Map
db DB
config *config.Config
}
func newFederatingDB(db DB, config *config.Config) pub.Database {
return &federatingDB{
locks: new(sync.Map),
db: db,
config: config,
}
}
/*
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
*/
func (f *federatingDB) Lock(ctx context.Context, id *url.URL) error {
// Before any other Database methods are called, the relevant `id`
// entries are locked to allow for fine-grained concurrency.
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
// existing mutex.
mu := &sync.Mutex{}
mu.Lock() // Optimistically lock if we do store it.
i, loaded := f.locks.LoadOrStore(id.String(), mu)
if loaded {
mu = i.(*sync.Mutex)
mu.Lock()
}
return nil
}
func (f *federatingDB) Unlock(ctx context.Context, id *url.URL) error {
// Once Go-Fed is done calling Database methods, the relevant `id`
// entries are unlocked.
i, ok := f.locks.Load(id.String())
if !ok {
return errors.New("missing an id in unlock")
}
mu := i.(*sync.Mutex)
mu.Unlock()
return nil
}
func (f *federatingDB) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
return false, nil
}
func (f *federatingDB) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
return nil, nil
}
func (f *federatingDB) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
return nil
}
func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
return id.Host == f.config.Host, nil
}
func (f *federatingDB) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
return nil, nil
}
func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
return nil, nil
}
func (f *federatingDB) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
return nil, nil
}
func (f *federatingDB) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
return false, nil
}
func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
return nil, nil
}
func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
t, err := streams.NewTypeResolver()
if err != nil {
return err
}
if err := t.Resolve(ctx, asType); err != nil {
return err
}
asType.GetTypeName()
return nil
}
func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
return nil
}
func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
return nil
}
func (f *federatingDB) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
return nil, nil
}
func (f *federatingDB) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
return nil
}
func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
return nil, nil
}
func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}
func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}
func (f *federatingDB) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for pgfed

363
internal/db/mock_DB.go Normal file
View file

@ -0,0 +1,363 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package db
import (
context "context"
mock "github.com/stretchr/testify/mock"
mastotypes "github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
model "github.com/superseriousbusiness/gotosocial/internal/db/model"
net "net"
pub "github.com/go-fed/activity/pub"
)
// MockDB is an autogenerated mock type for the DB type
type MockDB struct {
mock.Mock
}
// AccountToMastoSensitive provides a mock function with given fields: account
func (_m *MockDB) AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) {
ret := _m.Called(account)
var r0 *mastotypes.Account
if rf, ok := ret.Get(0).(func(*model.Account) *mastotypes.Account); ok {
r0 = rf(account)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*mastotypes.Account)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*model.Account) error); ok {
r1 = rf(account)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateTable provides a mock function with given fields: i
func (_m *MockDB) CreateTable(i interface{}) error {
ret := _m.Called(i)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(i)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteByID provides a mock function with given fields: id, i
func (_m *MockDB) DeleteByID(id string, i interface{}) error {
ret := _m.Called(id, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
r0 = rf(id, i)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteWhere provides a mock function with given fields: key, value, i
func (_m *MockDB) DeleteWhere(key string, value interface{}, i interface{}) error {
ret := _m.Called(key, value, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
r0 = rf(key, value, i)
} else {
r0 = ret.Error(0)
}
return r0
}
// DropTable provides a mock function with given fields: i
func (_m *MockDB) DropTable(i interface{}) error {
ret := _m.Called(i)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(i)
} else {
r0 = ret.Error(0)
}
return r0
}
// Federation provides a mock function with given fields:
func (_m *MockDB) Federation() pub.Database {
ret := _m.Called()
var r0 pub.Database
if rf, ok := ret.Get(0).(func() pub.Database); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(pub.Database)
}
}
return r0
}
// GetAccountByUserID provides a mock function with given fields: userID, account
func (_m *MockDB) GetAccountByUserID(userID string, account *model.Account) error {
ret := _m.Called(userID, account)
var r0 error
if rf, ok := ret.Get(0).(func(string, *model.Account) error); ok {
r0 = rf(userID, account)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetAll provides a mock function with given fields: i
func (_m *MockDB) GetAll(i interface{}) error {
ret := _m.Called(i)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(i)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetByID provides a mock function with given fields: id, i
func (_m *MockDB) GetByID(id string, i interface{}) error {
ret := _m.Called(id, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
r0 = rf(id, i)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetFollowRequestsForAccountID provides a mock function with given fields: accountID, followRequests
func (_m *MockDB) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
ret := _m.Called(accountID, followRequests)
var r0 error
if rf, ok := ret.Get(0).(func(string, *[]model.FollowRequest) error); ok {
r0 = rf(accountID, followRequests)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetFollowersByAccountID provides a mock function with given fields: accountID, followers
func (_m *MockDB) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
ret := _m.Called(accountID, followers)
var r0 error
if rf, ok := ret.Get(0).(func(string, *[]model.Follow) error); ok {
r0 = rf(accountID, followers)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetFollowingByAccountID provides a mock function with given fields: accountID, following
func (_m *MockDB) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
ret := _m.Called(accountID, following)
var r0 error
if rf, ok := ret.Get(0).(func(string, *[]model.Follow) error); ok {
r0 = rf(accountID, following)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetLastStatusForAccountID provides a mock function with given fields: accountID, status
func (_m *MockDB) GetLastStatusForAccountID(accountID string, status *model.Status) error {
ret := _m.Called(accountID, status)
var r0 error
if rf, ok := ret.Get(0).(func(string, *model.Status) error); ok {
r0 = rf(accountID, status)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetStatusesByAccountID provides a mock function with given fields: accountID, statuses
func (_m *MockDB) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
ret := _m.Called(accountID, statuses)
var r0 error
if rf, ok := ret.Get(0).(func(string, *[]model.Status) error); ok {
r0 = rf(accountID, statuses)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetStatusesByTimeDescending provides a mock function with given fields: accountID, statuses, limit
func (_m *MockDB) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
ret := _m.Called(accountID, statuses, limit)
var r0 error
if rf, ok := ret.Get(0).(func(string, *[]model.Status, int) error); ok {
r0 = rf(accountID, statuses, limit)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetWhere provides a mock function with given fields: key, value, i
func (_m *MockDB) GetWhere(key string, value interface{}, i interface{}) error {
ret := _m.Called(key, value, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
r0 = rf(key, value, i)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsEmailAvailable provides a mock function with given fields: email
func (_m *MockDB) IsEmailAvailable(email string) error {
ret := _m.Called(email)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(email)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsHealthy provides a mock function with given fields: ctx
func (_m *MockDB) IsHealthy(ctx context.Context) error {
ret := _m.Called(ctx)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsUsernameAvailable provides a mock function with given fields: username
func (_m *MockDB) IsUsernameAvailable(username string) error {
ret := _m.Called(username)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(username)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale, appID
func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale, appID)
var r0 *model.User
if rf, ok := ret.Get(0).(func(string, string, bool, string, string, net.IP, string, string) *model.User); ok {
r0 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.User)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string, bool, string, string, net.IP, string, string) error); ok {
r1 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Put provides a mock function with given fields: i
func (_m *MockDB) Put(i interface{}) error {
ret := _m.Called(i)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(i)
} else {
r0 = ret.Error(0)
}
return r0
}
// Stop provides a mock function with given fields: ctx
func (_m *MockDB) Stop(ctx context.Context) error {
ret := _m.Called(ctx)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateByID provides a mock function with given fields: id, i
func (_m *MockDB) UpdateByID(id string, i interface{}) error {
ret := _m.Called(id, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
r0 = rf(id, i)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -16,12 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
// Package model contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
package gtsmodel
// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
package model
import (
"crypto/rsa"
"net/url"
"time"
)
@ -37,20 +39,36 @@ type Account struct {
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
// Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
Domain string `pg:",unique:userdomain"` // username and domain
Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other
/*
ACCOUNT METADATA
*/
// Avatar image for this account
Avatar
// Header image for this account
Header
// File name of the avatar on local storage
AvatarFileName string
// Gif? png? jpeg?
AvatarContentType string
// Size of the avatar in bytes
AvatarFileSize int
// When was the avatar last updated?
AvatarUpdatedAt time.Time `pg:"type:timestamp"`
// Where can the avatar be retrieved?
AvatarRemoteURL *url.URL `pg:"type:text"`
// File name of the header on local storage
HeaderFileName string
// Gif? png? jpeg?
HeaderContentType string
// Size of the header in bytes
HeaderFileSize int
// When was the header last updated?
HeaderUpdatedAt time.Time `pg:"type:timestamp"`
// Where can the header be retrieved?
HeaderRemoteURL *url.URL `pg:"type:text"`
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
DisplayName string
// a key/value map of fields that this account has added to their profile
Fields map[string]string
Fields []Field
// A note that this account has on their profile (ie., the account's bio/description of themselves)
Note string
// Is this a memorial account, ie., has the user passed away?
@ -63,15 +81,25 @@ type Account struct {
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When should this account function until
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
// Does this account identify itself as a bot?
Bot bool
// What reason was given for signing up when this account was created?
Reason string
/*
PRIVACY SETTINGS
USER AND PRIVACY PREFERENCES
*/
// Does this account need an approval for new followers?
Locked bool
// Should this account be shown in the instance's profile directory?
Discoverable bool
// Default post privacy for this account
Privacy string
// Set posts from this account to sensitive by default?
Sensitive bool
// What language does this account post in?
Language string
/*
ACTIVITYPUB THINGS
@ -81,8 +109,6 @@ type Account struct {
URI string `pg:",unique"`
// At which URL can we see the user account in a web browser?
URL string `pg:",unique"`
// RemoteURL where this account is located. Will be empty if this is a local account.
RemoteURL string `pg:",unique"`
// Last time this account was located using the webfinger API.
LastWebfingeredAt time.Time `pg:"type:timestamp"`
// Address of this account's activitypub inbox, for sending activity to
@ -106,9 +132,9 @@ type Account struct {
Secret string
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
PrivateKey string
PrivateKey *rsa.PrivateKey
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PublicKey string
PublicKey *rsa.PublicKey
/*
ADMIN FIELDS
@ -128,28 +154,11 @@ type Account struct {
SuspensionOrigin int
}
// Avatar represents the avatar for the account for display purposes
type Avatar struct {
// File name of the avatar on local storage
AvatarFileName string
// Gif? png? jpeg?
AvatarContentType string
AvatarFileSize int
AvatarUpdatedAt *time.Time `pg:"type:timestamp"`
// Where can we retrieve the avatar?
AvatarRemoteURL *url.URL `pg:"type:text"`
AvatarStorageSchemaVersion int
}
// Header represents the header of the account for display purposes
type Header struct {
// File name of the header on local storage
HeaderFileName string
// Gif? png? jpeg?
HeaderContentType string
HeaderFileSize int
HeaderUpdatedAt *time.Time `pg:"type:timestamp"`
// Where can we retrieve the header?
HeaderRemoteURL *url.URL `pg:"type:text"`
HeaderStorageSchemaVersion int
// Field represents a key value field on an account, for things like pronouns, website, etc.
// VerifiedAt is optional, to be used only if Value is a URL to a webpage that contains the
// username of the user.
type Field struct {
Name string
Value string
VerifiedAt time.Time `pg:"type:timestamp"`
}

View file

@ -16,9 +16,9 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package gtsmodel
package model
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
import "github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
// Application represents an application that can perform actions on behalf of a user.
// It is used to authorize tokens etc, and is associated with an oauth client id in the database.
@ -41,8 +41,8 @@ type Application struct {
VapidKey string
}
// ToMastotype returns this application as a mastodon api type, ready for serialization
func (a *Application) ToMastotype() *mastotypes.Application {
// ToMasto returns this application as a mastodon api type, ready for serialization
func (a *Application) ToMasto() *mastotypes.Application {
return &mastotypes.Application{
ID: a.ID,
Name: a.Name,

View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package model
import "time"
// DomainBlock represents a federation block against a particular domain, of varying severity.
type DomainBlock struct {
// ID of this block in the database
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// Domain to block. If ANY PART of the candidate domain contains this string, it will be blocked.
// For example: 'example.org' also blocks 'gts.example.org'. '.com' blocks *any* '.com' domains.
// TODO: implement wildcards here
Domain string `pg:",notnull"`
// When was this block created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block
CreatedByAccountID string `pg:",notnull"`
// TODO: define this
Severity int
// Reject media from this domain?
RejectMedia bool
// Reject reports from this domain?
RejectReports bool
// Private comment on this block, viewable to admins
PrivateComment string
// Public comment on this block, viewable (optionally) by everyone
PublicComment string
}

View file

@ -0,0 +1,35 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package model
import "time"
// EmailDomainBlock represents a domain that the server should automatically reject sign-up requests from.
type EmailDomainBlock struct {
// ID of this block in the database
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// Email domain to block. Eg. 'gmail.com' or 'hotmail.com'
Domain string `pg:",notnull"`
// When was this block created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block
CreatedByAccountID string `pg:",notnull"`
}

View file

@ -0,0 +1,41 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package model
import "time"
// Follow represents one account following another, and the metadata around that follow.
type Follow struct {
// id of this follow in the database
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// When was this follow created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this follow last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow belong to?
AccountID string `pg:",unique:srctarget,notnull"`
// Who does AccountID follow?
TargetAccountID string `pg:",unique:srctarget,notnull"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow?
URI string `pg:",unique"`
// does the following account want to be notified when the followed account posts?
Notify bool
}

View file

@ -0,0 +1,41 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package model
import "time"
// FollowRequest represents one account requesting to follow another, and the metadata around that request.
type FollowRequest struct {
// id of this follow request in the database
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// When was this follow request created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this follow request last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow request originate from?
AccountID string `pg:",unique:srctarget,notnull"`
// Who is the target of this follow request?
TargetAccountID string `pg:",unique:srctarget,notnull"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow request?
URI string `pg:",unique"`
// does the following account want to be notified when the followed account posts?
Notify bool
}

View file

@ -0,0 +1,136 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package model
import (
"time"
)
// MediaAttachment represents a user-uploaded media attachment: an image/video/audio/gif that is
// somewhere in storage and that can be retrieved and served by the router.
type MediaAttachment struct {
// ID of the attachment in the database
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// ID of the status to which this is attached
StatusID string
// Where can the attachment be retrieved on a remote server
RemoteURL string
// When was the attachment created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was the attachment last updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Type of file (image/gif/audio/video)
Type FileType `pg:",notnull"`
// Metadata about the file
FileMeta FileMeta
// To which account does this attachment belong
AccountID string `pg:",notnull"`
// Description of the attachment (for screenreaders)
Description string
// To which scheduled status does this attachment belong
ScheduledStatusID string
// What is the generated blurhash of this attachment
Blurhash string
// What is the processing status of this attachment
Processing ProcessingStatus
// metadata for the whole file
File File
// small image thumbnail derived from a larger image, video, or audio file.
Thumbnail Thumbnail
// Is this attachment being used as an avatar?
Avatar bool
// Is this attachment being used as a header?
Header bool
}
// File refers to the metadata for the whole file
type File struct {
// What is the path of the file in storage.
Path string
// What is the MIME content type of the file.
ContentType string
// What is the size of the file in bytes.
FileSize int
// When was the file last updated.
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
}
// Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file.
type Thumbnail struct {
// What is the path of the file in storage
Path string
// What is the MIME content type of the file.
ContentType string
// What is the size of the file in bytes
FileSize int
// When was the file last updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// What is the remote URL of the thumbnail
RemoteURL string
}
// ProcessingStatus refers to how far along in the processing stage the attachment is.
type ProcessingStatus int
const (
// ProcessingStatusReceived: the attachment has been received and is awaiting processing. No thumbnail available yet.
ProcessingStatusReceived ProcessingStatus = 0
// ProcessingStatusProcessing: the attachment is currently being processed. Thumbnail is available but full media is not.
ProcessingStatusProcessing ProcessingStatus = 1
// ProcessingStatusProcessed: the attachment has been fully processed and is ready to be served.
ProcessingStatusProcessed ProcessingStatus = 2
// ProcessingStatusError: something went wrong processing the attachment and it won't be tried again--these can be deleted.
ProcessingStatusError ProcessingStatus = 666
)
// FileType refers to the file type of the media attaachment.
type FileType string
const (
// FileTypeImage is for jpegs and pngs
FileTypeImage FileType = "image"
// FileTypeGif is for native gifs and soundless videos that have been converted to gifs
FileTypeGif FileType = "gif"
// FileTypeAudio is for audio-only files (no video)
FileTypeAudio FileType = "audio"
// FileTypeVideo is for files with audio + visual
FileTypeVideo FileType = "video"
)
// FileMeta describes metadata about the actual contents of the file.
type FileMeta struct {
Original Original
Small Small
}
// Small implements SmallMeta and can be used for a thumbnail of any media type
type Small struct {
Width int
Height int
Size int
Aspect float64
}
// ImageOriginal implements OriginalMeta for still images
type Original struct {
Width int
Height int
Size int
Aspect float64
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package gtsmodel
package model
import "time"

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package gtsmodel
package model
import (
"net"
@ -33,7 +33,7 @@ type User struct {
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
Email string `pg:",notnull,unique"`
Email string `pg:"default:null,unique"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
AccountID string `pg:"default:'',notnull,unique"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables

View file

@ -1,137 +0,0 @@
package db
import (
"context"
"errors"
"net/url"
"sync"
"github.com/go-fed/activity/pub"
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
"github.com/go-pg/pg/v10"
)
type postgresFederation struct {
locks *sync.Map
conn *pg.DB
}
func newPostgresFederation(conn *pg.DB) pub.Database {
return &postgresFederation{
locks: new(sync.Map),
conn: conn,
}
}
/*
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
*/
func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error {
// Before any other Database methods are called, the relevant `id`
// entries are locked to allow for fine-grained concurrency.
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
// existing mutex.
mu := &sync.Mutex{}
mu.Lock() // Optimistically lock if we do store it.
i, loaded := pf.locks.LoadOrStore(id.String(), mu)
if loaded {
mu = i.(*sync.Mutex)
mu.Lock()
}
return nil
}
func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error {
// Once Go-Fed is done calling Database methods, the relevant `id`
// entries are unlocked.
i, ok := pf.locks.Load(id.String())
if !ok {
return errors.New("missing an id in unlock")
}
mu := i.(*sync.Mutex)
mu.Unlock()
return nil
}
func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
return false, nil
}
func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
return nil, nil
}
func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
return nil
}
func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
return false, nil
}
func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
return nil, nil
}
func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
return nil, nil
}
func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
return nil, nil
}
func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
return false, nil
}
func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
return nil, nil
}
func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error {
t, err := streams.NewTypeResolver()
if err != nil {
return err
}
if err := t.Resolve(ctx, asType); err != nil {
return err
}
asType.GetTypeName()
return nil
}
func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error {
return nil
}
func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error {
return nil
}
func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
return nil, nil
}
func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
return nil
}
func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
return nil, nil
}
func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}
func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}
func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
return nil, nil
}

View file

@ -20,8 +20,12 @@ package db
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"net/mail"
"regexp"
"strings"
"time"
@ -30,14 +34,17 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/gtsmodel"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"golang.org/x/crypto/bcrypt"
)
// postgresService satisfies the DB interface
type postgresService struct {
config *config.DBConfig
config *config.Config
conn *pg.DB
log *logrus.Entry
cancel context.CancelFunc
@ -46,7 +53,7 @@ type postgresService struct {
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@ -98,18 +105,18 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
return nil, errors.New("db connection timeout")
}
// we can confidently return this useable postgres service now
return &postgresService{
config: c.DBConfig,
ps := &postgresService{
config: c,
conn: conn,
log: log,
cancel: cancel,
federationDB: newPostgresFederation(conn),
}, nil
}
func (ps *postgresService) Federation() pub.Database {
return ps.federationDB
federatingDB := newFederatingDB(ps, c)
ps.federationDB = federatingDB
// we can confidently return this useable postgres service now
return ps, nil
}
/*
@ -168,9 +175,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
EXTRA FUNCTIONS
FEDERATION FUNCTIONALITY
*/
func (ps *postgresService) Federation() pub.Database {
return ps.federationDB
}
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
@ -181,11 +208,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
(*model.Account)(nil),
(*model.Status)(nil),
(*model.User)(nil),
}
ps.log.Info("creating db schema")
@ -202,32 +233,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) GetByID(id string, i interface{}) error {
return ps.conn.Model(i).Where("id = ?", id).Select()
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
return ps.conn.Model(i).Select()
if err := ps.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) Put(i interface{}) error {
@ -236,16 +270,393 @@ func (ps *postgresService) Put(i interface{}) error {
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
_, err := ps.conn.Model(i).Where("id = ?", id).Delete()
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
user := &model.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
q := ps.conn.Model(statuses).Order("created_at DESC")
if limit != 0 {
q = q.Limit(limit)
}
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) IsUsernameAvailable(username string) error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host)
a := &model.Account{
Username: username,
DisplayName: username,
Reason: reason,
URL: uris.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
ActorType: "Person",
URI: uris.UserURI,
InboxURL: uris.InboxURL,
OutboxURL: uris.OutboxURL,
FollowersURL: uris.FollowersURL,
FeaturedCollectionURL: uris.CollectionURL,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
u := &model.User{
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP,
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
_, err := ps.conn.Model(mediaAttachment).Insert()
return err
}
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
/*
CONVERSION FUNCTIONS
*/
// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
// that the account actually belongs to.
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
// we can build this sensitive account easily by first getting the public account....
mastoAccount, err := ps.AccountToMastoPublic(a)
if err != nil {
return nil, err
}
// then adding the Source object to it...
// check pending follow requests aimed at this account
fr := []model.FollowRequest{}
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting follow requests: %s", err)
}
}
var frc int
if fr != nil {
frc = len(fr)
}
mastoAccount.Source = &mastotypes.Source{
Privacy: a.Privacy,
Sensitive: a.Sensitive,
Language: a.Language,
Note: a.Note,
Fields: mastoAccount.Fields,
FollowRequestsCount: frc,
}
return mastoAccount, nil
}
func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) {
// count followers
followers := []model.Follow{}
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting followers: %s", err)
}
}
var followersCount int
if followers != nil {
followersCount = len(followers)
}
// count following
following := []model.Follow{}
if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting following: %s", err)
}
}
var followingCount int
if following != nil {
followingCount = len(following)
}
// count statuses
statuses := []model.Status{}
if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting last statuses: %s", err)
}
}
var statusesCount int
if statuses != nil {
statusesCount = len(statuses)
}
// check when the last status was
lastStatus := &model.Status{}
if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting last status: %s", err)
}
}
var lastStatusAt string
if lastStatus != nil {
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
}
// build the avatar and header URLs
avi := &model.MediaAttachment{}
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting avatar: %s", err)
}
}
aviURL := avi.File.Path
aviURLStatic := avi.Thumbnail.Path
header := &model.MediaAttachment{}
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting header: %s", err)
}
}
headerURL := header.File.Path
headerURLStatic := header.Thumbnail.Path
// get the fields set on this account
fields := []mastotypes.Field{}
for _, f := range a.Fields {
mField := mastotypes.Field{
Name: f.Name,
Value: f.Value,
}
if !f.VerifiedAt.IsZero() {
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
}
fields = append(fields, mField)
}
var acct string
if a.Domain != "" {
// this is a remote user
acct = fmt.Sprintf("%s@%s", a.Username, a.Domain)
} else {
// this is a local user
acct = a.Username
}
return &mastotypes.Account{
ID: a.ID,
Username: a.Username,
Acct: acct,
DisplayName: a.DisplayName,
Locked: a.Locked,
Bot: a.Bot,
CreatedAt: a.CreatedAt.Format(time.RFC3339),
Note: a.Note,
URL: a.URL,
Avatar: aviURL,
AvatarStatic: aviURLStatic,
Header: headerURL,
HeaderStatic: headerURLStatic,
FollowersCount: followersCount,
FollowingCount: followingCount,
StatusesCount: statusesCount,
LastStatusAt: lastStatusAt,
Emojis: nil, // TODO: implement this
Fields: fields,
}, nil
}

21
internal/db/pg_test.go Normal file
View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for postgres

View file

@ -0,0 +1,96 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package distributor
import (
"github.com/go-fed/activity/pub"
"github.com/sirupsen/logrus"
)
// Distributor should be passed to api modules (see internal/apimodule/...). It is used for
// passing messages back and forth from the client API and the federating interface, via channels.
// It also contains logic for filtering which messages should end up where.
// It is designed to be used asynchronously: the client API and the federating API should just be able to
// fire messages into the distributor and not wait for a reply before proceeding with other work. This allows
// for clean distribution of messages without slowing down the client API and harming the user experience.
type Distributor interface {
// ClientAPIIn returns a channel for accepting messages that come from the gts client API.
ClientAPIIn() chan interface{}
// ClientAPIOut returns a channel for putting in messages that need to go to the gts client API.
ClientAPIOut() chan interface{}
// Start starts the Distributor, reading from its channels and passing messages back and forth.
Start() error
// Stop stops the distributor cleanly, finishing handling any remaining messages before closing down.
Stop() error
}
// distributor just implements the Distributor interface
type distributor struct {
federator pub.FederatingActor
clientAPIIn chan interface{}
clientAPIOut chan interface{}
stop chan interface{}
log *logrus.Logger
}
// New returns a new Distributor that uses the given federator and logger
func New(federator pub.FederatingActor, log *logrus.Logger) Distributor {
return &distributor{
federator: federator,
clientAPIIn: make(chan interface{}, 100),
clientAPIOut: make(chan interface{}, 100),
stop: make(chan interface{}),
log: log,
}
}
// ClientAPIIn returns a channel for accepting messages that come from the gts client API.
func (d *distributor) ClientAPIIn() chan interface{} {
return d.clientAPIIn
}
// ClientAPIOut returns a channel for putting in messages that need to go to the gts client API.
func (d *distributor) ClientAPIOut() chan interface{} {
return d.clientAPIOut
}
// Start starts the Distributor, reading from its channels and passing messages back and forth.
func (d *distributor) Start() error {
go func() {
DistLoop:
for {
select {
case clientMsgIn := <-d.clientAPIIn:
d.log.Infof("received clientMsgIn: %+v", clientMsgIn)
case clientMsgOut := <-d.clientAPIOut:
d.log.Infof("received clientMsgOut: %+v", clientMsgOut)
case <-d.stop:
break DistLoop
}
}
}()
return nil
}
// Stop stops the distributor cleanly, finishing handling any remaining messages before closing down.
// TODO: empty message buffer properly before stopping otherwise we'll lose federating messages.
func (d *distributor) Stop() error {
close(d.stop)
return nil
}

View file

@ -0,0 +1,70 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package distributor
import mock "github.com/stretchr/testify/mock"
// MockDistributor is an autogenerated mock type for the Distributor type
type MockDistributor struct {
mock.Mock
}
// ClientAPIIn provides a mock function with given fields:
func (_m *MockDistributor) ClientAPIIn() chan interface{} {
ret := _m.Called()
var r0 chan interface{}
if rf, ok := ret.Get(0).(func() chan interface{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(chan interface{})
}
}
return r0
}
// ClientAPIOut provides a mock function with given fields:
func (_m *MockDistributor) ClientAPIOut() chan interface{} {
ret := _m.Called()
var r0 chan interface{}
if rf, ok := ret.Get(0).(func() chan interface{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(chan interface{})
}
}
return r0
}
// Start provides a mock function with given fields:
func (_m *MockDistributor) Start() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Stop provides a mock function with given fields:
func (_m *MockDistributor) Stop() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -27,88 +27,93 @@ import (
"github.com/go-fed/activity/pub"
"github.com/go-fed/activity/streams/vocab"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
// New returns a go-fed compatible federating actor
func New(db db.DB) pub.FederatingActor {
fa := &API{}
return pub.NewFederatingActor(fa, fa, db.Federation(), fa)
func New(db db.DB, log *logrus.Logger) pub.FederatingActor {
f := &Federator{
db: db,
}
return pub.NewFederatingActor(f, f, db.Federation(), f)
}
// API implements several go-fed interfaces in one convenient location
type API struct {
// Federator implements several go-fed interfaces in one convenient location
type Federator struct {
db db.DB
}
// AuthenticateGetInbox determines whether the request is for a GET call to the Actor's Inbox.
func (fa *API) AuthenticateGetInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
func (f *Federator) AuthenticateGetInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
// TODO
// use context.WithValue() and context.Value() to set and get values through here
return nil, false, nil
}
// AuthenticateGetOutbox determines whether the request is for a GET call to the Actor's Outbox.
func (fa *API) AuthenticateGetOutbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
func (f *Federator) AuthenticateGetOutbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
// TODO
return nil, false, nil
}
// GetOutbox returns a proper paginated view of the Outbox for serving in a response.
func (fa *API) GetOutbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
func (f *Federator) GetOutbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
// TODO
return nil, nil
}
// NewTransport returns a new pub.Transport for federating with peer software.
func (fa *API) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofedAgent string) (pub.Transport, error) {
func (f *Federator) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofedAgent string) (pub.Transport, error) {
// TODO
return nil, nil
}
func (fa *API) PostInboxRequestBodyHook(ctx context.Context, r *http.Request, activity pub.Activity) (context.Context, error) {
func (f *Federator) PostInboxRequestBodyHook(ctx context.Context, r *http.Request, activity pub.Activity) (context.Context, error) {
// TODO
return nil, nil
}
func (fa *API) AuthenticatePostInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
func (f *Federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
// TODO
return nil, false, nil
}
func (fa *API) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
func (f *Federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
// TODO
return false, nil
}
func (fa *API) FederatingCallbacks(ctx context.Context) (pub.FederatingWrappedCallbacks, []interface{}, error) {
func (f *Federator) FederatingCallbacks(ctx context.Context) (pub.FederatingWrappedCallbacks, []interface{}, error) {
// TODO
return pub.FederatingWrappedCallbacks{}, nil, nil
}
func (fa *API) DefaultCallback(ctx context.Context, activity pub.Activity) error {
func (f *Federator) DefaultCallback(ctx context.Context, activity pub.Activity) error {
// TODO
return nil
}
func (fa *API) MaxInboxForwardingRecursionDepth(ctx context.Context) int {
func (f *Federator) MaxInboxForwardingRecursionDepth(ctx context.Context) int {
// TODO
return 0
}
func (fa *API) MaxDeliveryRecursionDepth(ctx context.Context) int {
func (f *Federator) MaxDeliveryRecursionDepth(ctx context.Context) int {
// TODO
return 0
}
func (fa *API) FilterForwarding(ctx context.Context, potentialRecipients []*url.URL, a pub.Activity) ([]*url.URL, error) {
func (f *Federator) FilterForwarding(ctx context.Context, potentialRecipients []*url.URL, a pub.Activity) ([]*url.URL, error) {
// TODO
return nil, nil
}
func (fa *API) GetInbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
func (f *Federator) GetInbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
// TODO
return nil, nil
}
func (fa *API) Now() time.Time {
func (f *Federator) Now() time.Time {
return time.Now()
}

View file

@ -25,10 +25,20 @@ import (
"os/signal"
"syscall"
"github.com/gotosocial/gotosocial/internal/action"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/action"
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
"github.com/superseriousbusiness/gotosocial/internal/apimodule/account"
"github.com/superseriousbusiness/gotosocial/internal/apimodule/app"
"github.com/superseriousbusiness/gotosocial/internal/apimodule/auth"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/router"
"github.com/superseriousbusiness/gotosocial/internal/storage"
)
// Run creates and starts a gotosocial server
@ -38,9 +48,48 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr
return fmt.Errorf("error creating dbservice: %s", err)
}
// if err := dbService.CreateSchema(ctx); err != nil {
// return fmt.Errorf("error creating dbschema: %s", err)
// }
router, err := router.New(c, log)
if err != nil {
return fmt.Errorf("error creating router: %s", err)
}
storageBackend, err := storage.NewInMem(c, log)
if err != nil {
return fmt.Errorf("error creating storage backend: %s", err)
}
// build backend handlers
mediaHandler := media.New(c, dbService, storageBackend, log)
oauthServer := oauth.New(dbService, log)
// build client api modules
authModule := auth.New(oauthServer, dbService, log)
accountModule := account.New(c, dbService, oauthServer, mediaHandler, log)
appsModule := app.New(oauthServer, dbService, log)
apiModules := []apimodule.ClientAPIModule{
authModule, // this one has to go first so the other modules use its middleware
accountModule,
appsModule,
}
for _, m := range apiModules {
if err := m.Route(router); err != nil {
return fmt.Errorf("routing error: %s", err)
}
if err := m.CreateTables(dbService); err != nil {
return fmt.Errorf("table creation error: %s", err)
}
}
gts, err := New(dbService, &cache.MockCache{}, router, federation.New(dbService, log), c)
if err != nil {
return fmt.Errorf("error creating gotosocial service: %s", err)
}
if err := gts.Start(ctx); err != nil {
return fmt.Errorf("error starting gotosocial service: %s", err)
}
// catch shutdown signals from the operating system
sigs := make(chan os.Signal, 1)
@ -49,8 +98,8 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr
log.Infof("received signal %s, shutting down", sig)
// close down all running services in order
if err := dbService.Stop(ctx); err != nil {
return fmt.Errorf("error closing dbservice: %s", err)
if err := gts.Stop(ctx); err != nil {
return fmt.Errorf("error closing gotosocial service: %s", err)
}
log.Info("done! exiting...")

View file

@ -22,17 +22,22 @@ import (
"context"
"github.com/go-fed/activity/pub"
"github.com/gotosocial/gotosocial/internal/cache"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
// Gotosocial is the 'main' function of the gotosocial server, and the place where everything hangs together.
// The logic of stopping and starting the entire server is contained here.
type Gotosocial interface {
Start(context.Context) error
Stop(context.Context) error
}
// New returns a new gotosocial server, initialized with the given configuration.
// An error will be returned the caller if something goes wrong during initialization
// eg., no db or storage connection, port for router already in use, etc.
func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
return &gotosocial{
db: db,
@ -43,6 +48,7 @@ func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub
}, nil
}
// gotosocial fulfils the gotosocial interface.
type gotosocial struct {
db db.DB
cache cache.Cache
@ -51,10 +57,19 @@ type gotosocial struct {
config *config.Config
}
// Start starts up the gotosocial server. If something goes wrong
// while starting the server, then an error will be returned.
func (gts *gotosocial) Start(ctx context.Context) error {
gts.apiRouter.Start()
return nil
}
func (gts *gotosocial) Stop(ctx context.Context) error {
if err := gts.apiRouter.Stop(ctx); err != nil {
return err
}
if err := gts.db.Stop(ctx); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,28 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package gotosocial
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// MockGotosocial is an autogenerated mock type for the Gotosocial type
type MockGotosocial struct {
mock.Mock
}
// Start provides a mock function with given fields: _a0
func (_m *MockGotosocial) Start(_a0 context.Context) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -18,6 +18,195 @@
package media
// API provides an interface for parsing, storing, and retrieving media objects like photos and videos
type API interface {
import (
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/storage"
)
// MediaHandler provides an interface for parsing, storing, and retrieving media objects like photos, videos, and gifs.
type MediaHandler interface {
// SetHeaderOrAvatarForAccountID takes a new header image for an account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image,
// and then returns information to the caller about the new header.
SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error)
}
type mediaHandler struct {
config *config.Config
db db.DB
storage storage.Storage
log *logrus.Logger
}
func New(config *config.Config, database db.DB, storage storage.Storage, log *logrus.Logger) MediaHandler {
return &mediaHandler{
config: config,
db: database,
storage: storage,
log: log,
}
}
// HeaderInfo wraps the urls at which a Header and a StaticHeader is available from the server.
type HeaderInfo struct {
// URL to the header
Header string
// Static version of the above (eg., a path to a still image if the header is a gif)
HeaderStatic string
}
/*
INTERFACE FUNCTIONS
*/
func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
l := mh.log.WithField("func", "SetHeaderForAccountID")
if headerOrAvi != "header" && headerOrAvi != "avatar" {
return nil, errors.New("header or avatar not selected")
}
// make sure we have an image we can handle
contentType, err := parseContentType(img)
if err != nil {
return nil, err
}
if !supportedImageType(contentType) {
return nil, fmt.Errorf("%s is not an accepted image type", contentType)
}
if len(img) == 0 {
return nil, fmt.Errorf("passed reader was of size 0")
}
l.Tracef("read %d bytes of file", len(img))
// process it
ma, err := mh.processHeaderOrAvi(img, contentType, headerOrAvi, accountID)
if err != nil {
return nil, fmt.Errorf("error processing %s: %s", headerOrAvi, err)
}
// set it in the database
if err := mh.db.SetHeaderOrAvatarForAccountID(ma, accountID); err != nil {
return nil, fmt.Errorf("error putting %s in database: %s", headerOrAvi, err)
}
return ma, nil
}
/*
HELPER FUNCTIONS
*/
func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string, headerOrAvi string, accountID string) (*model.MediaAttachment, error) {
var isHeader bool
var isAvatar bool
switch headerOrAvi {
case "header":
isHeader = true
case "avatar":
isAvatar = true
default:
return nil, errors.New("header or avatar not selected")
}
var clean []byte
var err error
switch contentType {
case "image/jpeg":
if clean, err = purgeExif(imageBytes); err != nil {
return nil, fmt.Errorf("error cleaning exif data: %s", err)
}
case "image/png":
if clean, err = purgeExif(imageBytes); err != nil {
return nil, fmt.Errorf("error cleaning exif data: %s", err)
}
case "image/gif":
clean = imageBytes
default:
return nil, errors.New("media type unrecognized")
}
original, err := deriveImage(clean, contentType)
if err != nil {
return nil, fmt.Errorf("error parsing image: %s", err)
}
small, err := deriveThumbnail(clean, contentType)
if err != nil {
return nil, fmt.Errorf("error deriving thumbnail: %s", err)
}
// now put it in storage, take a new uuid for the name of the file so we don't store any unnecessary info about it
extension := strings.Split(contentType, "/")[1]
newMediaID := uuid.NewString()
base := fmt.Sprintf("%s://%s%s", mh.config.StorageConfig.ServeProtocol, mh.config.StorageConfig.ServeHost, mh.config.StorageConfig.ServeBasePath)
// we store the original...
originalPath := fmt.Sprintf("%s/%s/%s/original/%s.%s", base, accountID, headerOrAvi, newMediaID, extension)
if err := mh.storage.StoreFileAt(originalPath, original.image); err != nil {
return nil, fmt.Errorf("storage error: %s", err)
}
// and a thumbnail...
smallPath := fmt.Sprintf("%s/%s/%s/small/%s.%s", base, accountID, headerOrAvi, newMediaID, extension)
if err := mh.storage.StoreFileAt(smallPath, small.image); err != nil {
return nil, fmt.Errorf("storage error: %s", err)
}
ma := &model.MediaAttachment{
ID: newMediaID,
StatusID: "",
RemoteURL: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Type: model.FileTypeImage,
FileMeta: model.FileMeta{
Original: model.Original{
Width: original.width,
Height: original.height,
Size: original.size,
Aspect: original.aspect,
},
Small: model.Small{
Width: small.width,
Height: small.height,
Size: small.size,
Aspect: small.aspect,
},
},
AccountID: accountID,
Description: "",
ScheduledStatusID: "",
Blurhash: original.blurhash,
Processing: 2,
File: model.File{
Path: originalPath,
ContentType: contentType,
FileSize: len(original.image),
UpdatedAt: time.Now(),
},
Thumbnail: model.Thumbnail{
Path: smallPath,
ContentType: contentType,
FileSize: len(small.image),
UpdatedAt: time.Now(),
RemoteURL: "",
},
Avatar: isAvatar,
Header: isHeader,
}
return ma, nil
}

View file

@ -0,0 +1,159 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package media
import (
"context"
"io/ioutil"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/storage"
)
type MediaTestSuite struct {
suite.Suite
config *config.Config
log *logrus.Logger
db db.DB
mediaHandler *mediaHandler
mockStorage *storage.MockStorage
}
/*
TEST INFRASTRUCTURE
*/
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *MediaTestSuite) SetupSuite() {
// some of our subsequent entities need a log so create this here
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
// Direct config to local postgres instance
c := config.Empty()
c.Protocol = "http"
c.Host = "localhost"
c.DBConfig = &config.DBConfig{
Type: "postgres",
Address: "localhost",
Port: 5432,
User: "postgres",
Password: "postgres",
Database: "postgres",
ApplicationName: "gotosocial",
}
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
c.StorageConfig = &config.StorageConfig{
Backend: "local",
BasePath: "/tmp",
ServeProtocol: "http",
ServeHost: "localhost",
ServeBasePath: "/fileserver/media",
}
suite.config = c
// use an actual database for this, because it's just easier than mocking one out
database, err := db.New(context.Background(), c, log)
if err != nil {
suite.FailNow(err.Error())
}
suite.db = database
suite.mockStorage = &storage.MockStorage{}
// We don't need storage to do anything for these tests, so just simulate a success and do nothing
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
// and finally here's the thing we're actually testing!
suite.mediaHandler = &mediaHandler{
config: suite.config,
db: suite.db,
storage: suite.mockStorage,
log: log,
}
}
func (suite *MediaTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
}
// SetupTest creates a db connection and creates necessary tables before each test
func (suite *MediaTestSuite) SetupTest() {
// create all the tables we might need in thie suite
models := []interface{}{
&model.Account{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.CreateTable(m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
}
// TearDownTest drops tables to make sure there's no data in the db
func (suite *MediaTestSuite) TearDownTest() {
// remove all the tables we might have used so it's clear for the next test
models := []interface{}{
&model.Account{},
&model.MediaAttachment{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
}
/*
ACTUAL TESTS
*/
func (suite *MediaTestSuite) TestSetHeaderOrAvatarForAccountID() {
// load test image
f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err)
ma, err := suite.mediaHandler.SetHeaderOrAvatarForAccountID(f, "weeeeeee", "header")
assert.Nil(suite.T(), err)
suite.log.Debugf("%+v", ma)
// attachment should have....
assert.Equal(suite.T(), "weeeeeee", ma.AccountID)
assert.Equal(suite.T(), "LjCZnlvyRkRn_NvzRjWF?urqV@f9", ma.Blurhash)
//TODO: add more checks here, cba right now!
}
// TODO: add tests for sad path, gif, png....
func TestMediaTestSuite(t *testing.T) {
suite.Run(t, new(MediaTestSuite))
}

View file

@ -0,0 +1,36 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package media
import (
mock "github.com/stretchr/testify/mock"
model "github.com/superseriousbusiness/gotosocial/internal/db/model"
)
// MockMediaHandler is an autogenerated mock type for the MediaHandler type
type MockMediaHandler struct {
mock.Mock
}
// SetHeaderOrAvatarForAccountID provides a mock function with given fields: img, accountID, headerOrAvi
func (_m *MockMediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
ret := _m.Called(img, accountID, headerOrAvi)
var r0 *model.MediaAttachment
if rf, ok := ret.Get(0).(func([]byte, string, string) *model.MediaAttachment); ok {
r0 = rf(img, accountID, headerOrAvi)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.MediaAttachment)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]byte, string, string) error); ok {
r1 = rf(img, accountID, headerOrAvi)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

File diff suppressed because one or more lines are too long

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 293 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 263 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

192
internal/media/util.go Normal file
View file

@ -0,0 +1,192 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package media
import (
"bytes"
"errors"
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"github.com/buckket/go-blurhash"
"github.com/h2non/filetype"
"github.com/nfnt/resize"
"github.com/superseriousbusiness/exifremove/pkg/exifremove"
)
// parseContentType parses the MIME content type from a file, returning it as a string in the form (eg., "image/jpeg").
// Returns an error if the content type is not something we can process.
func parseContentType(content []byte) (string, error) {
head := make([]byte, 261)
_, err := bytes.NewReader(content).Read(head)
if err != nil {
return "", fmt.Errorf("could not read first magic bytes of file: %s", err)
}
kind, err := filetype.Match(head)
if err != nil {
return "", err
}
if kind == filetype.Unknown {
return "", errors.New("filetype unknown")
}
return kind.MIME.Value, nil
}
// supportedImageType checks mime type of an image against a slice of accepted types,
// and returns True if the mime type is accepted.
func supportedImageType(mimeType string) bool {
acceptedImageTypes := []string{
"image/jpeg",
"image/gif",
"image/png",
}
for _, accepted := range acceptedImageTypes {
if mimeType == accepted {
return true
}
}
return false
}
// purgeExif is a little wrapper for the action of removing exif data from an image.
// Only pass pngs or jpegs to this function.
func purgeExif(b []byte) ([]byte, error) {
if len(b) == 0 {
return nil, errors.New("passed image was not valid")
}
clean, err := exifremove.Remove(b)
if err != nil {
return nil, fmt.Errorf("could not purge exif from image: %s", err)
}
if len(clean) == 0 {
return nil, errors.New("purged image was not valid")
}
return clean, nil
}
func deriveImage(b []byte, extension string) (*imageAndMeta, error) {
var i image.Image
var err error
switch extension {
case "image/jpeg":
i, err = jpeg.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
case "image/png":
i, err = png.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
case "image/gif":
i, err = gif.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("extension %s not recognised", extension)
}
width := i.Bounds().Size().X
height := i.Bounds().Size().Y
size := width * height
aspect := float64(width) / float64(height)
bh, err := blurhash.Encode(4, 3, i)
if err != nil {
return nil, fmt.Errorf("error generating blurhash: %s", err)
}
out := &bytes.Buffer{}
if err := jpeg.Encode(out, i, nil); err != nil {
return nil, err
}
return &imageAndMeta{
image: out.Bytes(),
width: width,
height: height,
size: size,
aspect: aspect,
blurhash: bh,
}, nil
}
// deriveThumbnailFromImage returns a byte slice and metadata for a 256-pixel-width thumbnail
// of a given jpeg, png, or gif, or an error if something goes wrong.
//
// Note that the aspect ratio of the image will be retained,
// so it will not necessarily be a square.
func deriveThumbnail(b []byte, extension string) (*imageAndMeta, error) {
var i image.Image
var err error
switch extension {
case "image/jpeg":
i, err = jpeg.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
case "image/png":
i, err = png.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
case "image/gif":
i, err = gif.Decode(bytes.NewReader(b))
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("extension %s not recognised", extension)
}
thumb := resize.Thumbnail(256, 256, i, resize.NearestNeighbor)
width := thumb.Bounds().Size().X
height := thumb.Bounds().Size().Y
size := width * height
aspect := float64(width) / float64(height)
out := &bytes.Buffer{}
if err := jpeg.Encode(out, thumb, nil); err != nil {
return nil, err
}
return &imageAndMeta{
image: out.Bytes(),
width: width,
height: height,
size: size,
aspect: aspect,
}, nil
}
type imageAndMeta struct {
image []byte
width int
height int
size int
aspect float64
blurhash string
}

147
internal/media/util_test.go Normal file
View file

@ -0,0 +1,147 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package media
import (
"io/ioutil"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type MediaUtilTestSuite struct {
suite.Suite
log *logrus.Logger
}
/*
TEST INFRASTRUCTURE
*/
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *MediaUtilTestSuite) SetupSuite() {
// some of our subsequent entities need a log so create this here
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
}
func (suite *MediaUtilTestSuite) TearDownSuite() {
}
// SetupTest creates a db connection and creates necessary tables before each test
func (suite *MediaUtilTestSuite) SetupTest() {
}
// TearDownTest drops tables to make sure there's no data in the db
func (suite *MediaUtilTestSuite) TearDownTest() {
}
/*
ACTUAL TESTS
*/
func (suite *MediaUtilTestSuite) TestParseContentTypeOK() {
f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err)
ct, err := parseContentType(f)
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), "image/jpeg", ct)
}
func (suite *MediaUtilTestSuite) TestParseContentTypeNotOK() {
f, err := ioutil.ReadFile("./test/test-corrupted.jpg")
assert.Nil(suite.T(), err)
ct, err := parseContentType(f)
assert.NotNil(suite.T(), err)
assert.Equal(suite.T(), "", ct)
assert.Equal(suite.T(), "filetype unknown", err.Error())
}
func (suite *MediaUtilTestSuite) TestRemoveEXIF() {
// load and validate image
b, err := ioutil.ReadFile("./test/test-with-exif.jpg")
assert.Nil(suite.T(), err)
// clean it up and validate the clean version
clean, err := purgeExif(b)
assert.Nil(suite.T(), err)
// compare it to our stored sample
sampleBytes, err := ioutil.ReadFile("./test/test-without-exif.jpg")
assert.Nil(suite.T(), err)
assert.EqualValues(suite.T(), sampleBytes, clean)
}
func (suite *MediaUtilTestSuite) TestDeriveImageFromJPEG() {
// load image
b, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err)
// clean it up and validate the clean version
imageAndMeta, err := deriveImage(b, "image/jpeg")
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), 1920, imageAndMeta.width)
assert.Equal(suite.T(), 1080, imageAndMeta.height)
assert.Equal(suite.T(), 1.7777777777777777, imageAndMeta.aspect)
assert.Equal(suite.T(), 2073600, imageAndMeta.size)
assert.Equal(suite.T(), "LjCZnlvyRkRn_NvzRjWF?urqV@f9", imageAndMeta.blurhash)
// assert that the final image is what we would expect
sampleBytes, err := ioutil.ReadFile("./test/test-jpeg-processed.jpg")
assert.Nil(suite.T(), err)
assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image)
}
func (suite *MediaUtilTestSuite) TestDeriveThumbnailFromJPEG() {
// load image
b, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err)
// clean it up and validate the clean version
imageAndMeta, err := deriveThumbnail(b, "image/jpeg")
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), 256, imageAndMeta.width)
assert.Equal(suite.T(), 144, imageAndMeta.height)
assert.Equal(suite.T(), 1.7777777777777777, imageAndMeta.aspect)
assert.Equal(suite.T(), 36864, imageAndMeta.size)
sampleBytes, err := ioutil.ReadFile("./test/test-jpeg-thumbnail.jpg")
assert.Nil(suite.T(), err)
assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image)
}
func (suite *MediaUtilTestSuite) TestSupportedImageTypes() {
ok := supportedImageType("image/jpeg")
assert.True(suite.T(), ok)
ok = supportedImageType("image/bmp")
assert.False(suite.T(), ok)
}
func TestMediaUtilTestSuite(t *testing.T) {
suite.Run(t, new(MediaUtilTestSuite))
}

View file

@ -1,510 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Package oauth is a module that provides oauth functionality to a router.
// It adds the following paths:
// /api/v1/apps
// /auth/sign_in
// /oauth/token
// /oauth/authorize
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
package oauth
import (
"fmt"
"net/http"
"net/url"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/gtsmodel"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/errors"
"github.com/gotosocial/oauth2/v4/manage"
"github.com/gotosocial/oauth2/v4/server"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
const (
appsPath = "/api/v1/apps"
authSignInPath = "/auth/sign_in"
oauthTokenPath = "/oauth/token"
oauthAuthorizePath = "/oauth/authorize"
)
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
type oauthModule struct {
oauthManager *manage.Manager
oauthServer *server.Server
db db.DB
log *logrus.Logger
}
type login struct {
Email string `form:"username"`
Password string `form:"password"`
}
// New returns a new oauth module
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
m := &oauthModule{
oauthManager: manager,
oauthServer: srv,
db: db,
log: log,
}
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
return m
}
// Route satisfies the RESTAPIModule interface
func (m *oauthModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
s.AttachMiddleware(m.oauthTokenMiddleware)
return nil
}
/*
MAIN HANDLERS -- serve these through a server/router
*/
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &gtsmodel.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauthClient{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMastotype())
}
// signInGETHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
func (m *oauthModule) signInGETHandler(c *gin.Context) {
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
}
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The handler will then redirect to the auth handler served at /auth
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "SignInPOSTHandler")
s := sessions.Default(c)
form := &login{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
l.Tracef("parsed form: %+v", form)
userid, err := m.validatePassword(form.Email, form.Password)
if err != nil {
c.String(http.StatusForbidden, err.Error())
return
}
s.Set("userid", userid)
if err := s.Save(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
l.Trace("redirecting to auth page")
c.Redirect(http.StatusFound, oauthAuthorizePath)
}
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "TokenPOSTHandler")
l.Trace("entered TokenPOSTHandler")
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
// The idea here is to present an oauth authorize page to the user, with a button
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizeGETHandler")
s := sessions.Default(c)
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
userID, ok := s.Get("userid").(string)
if !ok || userID == "" {
l.Trace("userid was empty, parsing form then redirecting to sign in page")
if err := parseAuthForm(c, l); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
} else {
c.Redirect(http.StatusFound, authSignInPath)
}
return
}
// We can use the client_id on the session to retrieve info about the app associated with the client_id
clientID, ok := s.Get("client_id").(string)
if !ok || clientID == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
app := &gtsmodel.Application{
ClientID: clientID,
}
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
user := &gtsmodel.User{
ID: userID,
}
if err := m.db.GetByID(user.ID, user); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acct := &gtsmodel.Account{
ID: user.AccountID,
}
if err := m.db.GetByID(acct.ID, acct); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Finally we should also get the redirect and scope of this particular request, as stored in the session.
redirect, ok := s.Get("redirect_uri").(string)
if !ok || redirect == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
return
}
scope, ok := s.Get("scope").(string)
if !ok || scope == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
return
}
// the authorize template will display a form to the user where they can get some information
// about the app that's trying to authorize, and the scope of the request.
// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
l.Trace("serving authorize html")
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
"appname": app.Name,
"appwebsite": app.Website,
"redirect": redirect,
"scope": scope,
"user": acct.Username,
})
}
// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizePOSTHandler")
s := sessions.Default(c)
// At this point we know the user has said 'yes' to allowing the application and oauth client
// work for them, so we can set the
// We need to retrieve the original form submitted to the authorizeGEThandler, and
// recreate it on the request so that it can be used further by the oauth2 library.
// So first fetch all the values from the session.
forceLogin, ok := s.Get("force_login").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
return
}
responseType, ok := s.Get("response_type").(string)
if !ok || responseType == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
return
}
clientID, ok := s.Get("client_id").(string)
if !ok || clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
return
}
redirectURI, ok := s.Get("redirect_uri").(string)
if !ok || redirectURI == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
return
}
scope, ok := s.Get("scope").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
return
}
userID, ok := s.Get("userid").(string)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
return
}
// we're done with the session so we can clear it now
s.Clear()
// now set the values on the request
values := url.Values{}
values.Set("force_login", forceLogin)
values.Set("response_type", responseType)
values.Set("client_id", clientID)
values.Set("redirect_uri", redirectURI)
values.Set("scope", scope)
values.Set("userid", userID)
c.Request.Form = values
l.Tracef("values on request set to %+v", c.Request.Form)
// and proceed with authorization using the oauth2 library
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
}
/*
MIDDLEWARE
*/
// oauthTokenMiddleware
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
l := m.log.WithField("func", "ValidatePassword")
l.Trace("entering OauthTokenMiddleware")
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
c.Set("authenticated_user", ti.GetUserID())
} else {
l.Trace("continuing with unauthenticated request")
}
}
/*
SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs
*/
// validatePassword takes an email address and a password.
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a uuid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
if email == "" || password == "" {
l.Debug("email or password was not provided")
return incorrectPassword()
}
// first we select the user from the database based on email address, bail if no user found for that email
gtsUser := &gtsmodel.User{}
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword()
}
// make sure a password is actually set and bail if not
if gtsUser.EncryptedPassword == "" {
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
return incorrectPassword()
}
// compare the provided password with the encrypted one from the db, bail if they don't match
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
return incorrectPassword()
}
// If we've made it this far the email/password is correct, so we can just return the id of the user.
userid = gtsUser.ID
l.Tracef("returning (%s, %s)", userid, err)
return
}
// incorrectPassword is just a little helper function to use in the ValidatePassword function
func incorrectPassword() (string, error) {
return "", errors.New("password/email combination was incorrect")
}
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
// or redirects to the /auth/sign_in page, if this key is not present.
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
l := m.log.WithField("func", "UserAuthorizationHandler")
userID = r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty, redirecting to sign in page")
}
l.Tracef("returning userID %s", userID)
return userID, err
}
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
// the values in the form into the session.
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
s := sessions.Default(c)
// first make sure they've filled out the authorize form with the required values
form := &mastotypes.OAuthAuthorize{}
if err := c.ShouldBind(form); err != nil {
return err
}
l.Tracef("parsed form: %+v", form)
// these fields are *required* so check 'em
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
return errors.New("missing one of: response_type, client_id or redirect_uri")
}
// set default scope to read
if form.Scope == "" {
form.Scope = "read"
}
// save these values from the form so we can use them elsewhere in the session
s.Set("force_login", form.ForceLogin)
s.Set("response_type", form.ResponseType)
s.Set("client_id", form.ClientID)
s.Set("redirect_uri", form.RedirectURI)
s.Set("scope", form.Scope)
return s.Save()
}

View file

@ -20,11 +20,10 @@ package oauth
import (
"context"
"fmt"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/models"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/models"
)
type clientStore struct {
@ -39,17 +38,17 @@ func newClientStore(db db.DB) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
poc := &oauthClient{
poc := &Client{
ID: clientID,
}
if err := cs.db.GetByID(clientID, poc); err != nil {
return nil, fmt.Errorf("database error: %s", err)
return nil, err
}
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
}
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
poc := &oauthClient{
poc := &Client{
ID: cli.GetID(),
Secret: cli.GetSecret(),
Domain: cli.GetDomain(),
@ -59,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
}
func (cs *clientStore) Delete(ctx context.Context, id string) error {
poc := &oauthClient{
poc := &Client{
ID: id,
}
return cs.db.DeleteByID(id, poc)
}
type oauthClient struct {
type Client struct {
ID string
Secret string
Domain string

View file

@ -21,11 +21,11 @@ import (
"context"
"testing"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/oauth2/v4/models"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/oauth2/v4/models"
)
type PgClientStoreTestSuite struct {
@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
suite.db = db
models := []interface{}{
&oauthClient{},
&Client{},
}
for _, m := range models {
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *PgClientStoreTestSuite) TearDownTest() {
models := []interface{}{
&oauthClient{},
&Client{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
@ -136,7 +136,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
suite.Assert().NotNil(err)
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {

View file

@ -0,0 +1,89 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package oauth
import (
http "net/http"
mock "github.com/stretchr/testify/mock"
oauth2 "github.com/superseriousbusiness/oauth2/v4"
)
// MockServer is an autogenerated mock type for the Server type
type MockServer struct {
mock.Mock
}
// GenerateUserAccessToken provides a mock function with given fields: ti, clientSecret, userID
func (_m *MockServer) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (oauth2.TokenInfo, error) {
ret := _m.Called(ti, clientSecret, userID)
var r0 oauth2.TokenInfo
if rf, ok := ret.Get(0).(func(oauth2.TokenInfo, string, string) oauth2.TokenInfo); ok {
r0 = rf(ti, clientSecret, userID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(oauth2.TokenInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(oauth2.TokenInfo, string, string) error); ok {
r1 = rf(ti, clientSecret, userID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HandleAuthorizeRequest provides a mock function with given fields: w, r
func (_m *MockServer) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
ret := _m.Called(w, r)
var r0 error
if rf, ok := ret.Get(0).(func(http.ResponseWriter, *http.Request) error); ok {
r0 = rf(w, r)
} else {
r0 = ret.Error(0)
}
return r0
}
// HandleTokenRequest provides a mock function with given fields: w, r
func (_m *MockServer) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
ret := _m.Called(w, r)
var r0 error
if rf, ok := ret.Get(0).(func(http.ResponseWriter, *http.Request) error); ok {
r0 = rf(w, r)
} else {
r0 = ret.Error(0)
}
return r0
}
// ValidationBearerToken provides a mock function with given fields: r
func (_m *MockServer) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
ret := _m.Called(r)
var r0 oauth2.TokenInfo
if rf, ok := ret.Get(0).(func(*http.Request) oauth2.TokenInfo); ok {
r0 = rf(r)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(oauth2.TokenInfo)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*http.Request) error); ok {
r1 = rf(r)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

254
internal/oauth/server.go Normal file
View file

@ -0,0 +1,254 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
"github.com/superseriousbusiness/oauth2/v4/manage"
"github.com/superseriousbusiness/oauth2/v4/server"
)
const (
SessionAuthorizedToken = "authorized_token"
// SessionAuthorizedUser is the key set in the gin context for the id of
// a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
SessionAuthorizedUser = "authorized_user"
// SessionAuthorizedAccount is the key set in the gin context for the Account
// of a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
SessionAuthorizedAccount = "authorized_account"
// SessionAuthorizedAccount is the key set in the gin context for the Application
// of a Client who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
SessionAuthorizedApplication = "authorized_app"
)
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
type Server interface {
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
}
// s fulfils the Server interface using the underlying oauth2 server
type s struct {
server *server.Server
log *logrus.Logger
}
type Authed struct {
Token oauth2.TokenInfo
Application *model.Application
User *model.User
Account *model.Account
}
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
// In essence, it tries to extract a token, application, user, and account from the context,
// and then sets them on a struct for convenience.
//
// If any are not present in the context, they will be set to nil on the returned Authed struct.
//
// If *ALL* are not present, then nil and an error will be returned.
//
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
func GetAuthed(c *gin.Context) (*Authed, error) {
ctx := c.Copy()
a := &Authed{}
var i interface{}
var ok bool
i, ok = ctx.Get(SessionAuthorizedToken)
if ok {
parsed, ok := i.(oauth2.TokenInfo)
if !ok {
return nil, errors.New("could not parse token from session context")
}
a.Token = parsed
}
i, ok = ctx.Get(SessionAuthorizedApplication)
if ok {
parsed, ok := i.(*model.Application)
if !ok {
return nil, errors.New("could not parse application from session context")
}
a.Application = parsed
}
i, ok = ctx.Get(SessionAuthorizedUser)
if ok {
parsed, ok := i.(*model.User)
if !ok {
return nil, errors.New("could not parse user from session context")
}
a.User = parsed
}
i, ok = ctx.Get(SessionAuthorizedAccount)
if ok {
parsed, ok := i.(*model.Account)
if !ok {
return nil, errors.New("could not parse account from session context")
}
a.Account = parsed
}
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
return nil, errors.New("not authorized")
}
return a, nil
}
// MustAuth is like GetAuthed, but will fail if one of the requirements is not met.
func MustAuth(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Authed, error) {
a, err := GetAuthed(c)
if err != nil {
return nil, err
}
if requireToken && a.Token == nil {
return nil, errors.New("token not supplied")
}
if requireApp && a.Application == nil {
return nil, errors.New("application not supplied")
}
if requireUser && a.User == nil {
return nil, errors.New("user not supplied")
}
if requireAccount && a.Account == nil {
return nil, errors.New("account not supplied")
}
return a, nil
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleTokenRequest(w, r)
}
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleAuthorizeRequest(w, r)
}
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
return s.server.ValidationBearerToken(r)
}
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
// bearer token *without* requiring that user to log in. This is useful when we
// need to create a token for new users who haven't validated their email or logged in yet.
//
// The ti parameter refers to an existing Application token that was used to make the upstream
// request. This token needs to be validated and exist in database in order to create a new token.
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (oauth2.TokenInfo, error) {
authToken, err := s.server.Manager.GenerateAuthToken(context.Background(), oauth2.Code, &oauth2.TokenGenerateRequest{
ClientID: ti.GetClientID(),
ClientSecret: clientSecret,
UserID: userID,
RedirectURI: ti.GetRedirectURI(),
Scope: ti.GetScope(),
})
if err != nil {
return nil, fmt.Errorf("error generating auth token: %s", err)
}
if authToken == nil {
return nil, errors.New("generated auth token was empty")
}
s.log.Tracef("obtained auth token: %+v", authToken)
accessToken, err := s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, &oauth2.TokenGenerateRequest{
ClientID: authToken.GetClientID(),
ClientSecret: clientSecret,
RedirectURI: authToken.GetRedirectURI(),
Scope: authToken.GetScope(),
Code: authToken.GetCode(),
})
if err != nil {
return nil, fmt.Errorf("error generating user-level access token: %s", err)
}
if accessToken == nil {
return nil, errors.New("generated user-level access token was empty")
}
s.log.Tracef("obtained user-level access token: %+v", accessToken)
return accessToken, nil
}
func New(database db.DB, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := newClientStore(database)
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
})
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
log: log,
}
}

View file

@ -24,10 +24,10 @@ import (
"fmt"
"time"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/models"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/models"
)
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
func (pts *tokenStore) sweep() error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*oauthToken)
tokens := new([]*Token)
if err := pts.db.GetAll(tokens); err != nil {
return err
}
@ -92,7 +92,7 @@ func (pts *tokenStore) sweep() error {
}
// Create creates and store the new token information.
// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
@ -106,22 +106,25 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
// RemoveByCode deletes a token from the DB based on the Code field
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return pts.db.DeleteWhere("code", code, &oauthToken{})
return pts.db.DeleteWhere("code", code, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return pts.db.DeleteWhere("access", access, &oauthToken{})
return pts.db.DeleteWhere("access", access, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
return pts.db.DeleteWhere("refresh", refresh, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
if code == "" {
return nil, nil
}
pgt := &Token{
Code: code,
}
if err := pts.db.GetWhere("code", code, pgt); err != nil {
@ -132,7 +135,10 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
// GetByAccess selects a token from the DB based on the Access field
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
if access == "" {
return nil, nil
}
pgt := &Token{
Access: access,
}
if err := pts.db.GetWhere("access", access, pgt); err != nil {
@ -143,7 +149,10 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
// GetByRefresh selects a token from the DB based on the Refresh field
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
if refresh == "" {
return nil, nil
}
pgt := &Token{
Refresh: refresh,
}
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
@ -156,17 +165,17 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
*/
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
// then periodically sweep out tokens when that time has passed.
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
// and implemented here: https://github.com/superseriousbusiness/oauth2/blob/master/models/token.go.
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type oauthToken struct {
type Token struct {
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
ClientID string
UserID string
@ -186,7 +195,7 @@ type oauthToken struct {
}
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
func oauthTokenToPGToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -208,7 +217,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
rea = now.Add(tkn.RefreshExpiresIn)
}
return &oauthToken{
return &Token{
ClientID: tkn.ClientID,
UserID: tkn.UserID,
RedirectURI: tkn.RedirectURI,
@ -228,7 +237,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
}
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
func pgTokenToOauthToken(pgt *Token) *models.Token {
now := time.Now()
return &models.Token{

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

View file

@ -0,0 +1,44 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package router
import (
context "context"
gin "github.com/gin-gonic/gin"
mock "github.com/stretchr/testify/mock"
)
// MockRouter is an autogenerated mock type for the Router type
type MockRouter struct {
mock.Mock
}
// AttachHandler provides a mock function with given fields: method, path, f
func (_m *MockRouter) AttachHandler(method string, path string, f gin.HandlerFunc) {
_m.Called(method, path, f)
}
// AttachMiddleware provides a mock function with given fields: handler
func (_m *MockRouter) AttachMiddleware(handler gin.HandlerFunc) {
_m.Called(handler)
}
// Start provides a mock function with given fields:
func (_m *MockRouter) Start() {
_m.Called()
}
// Stop provides a mock function with given fields: ctx
func (_m *MockRouter) Stop(ctx context.Context) error {
ret := _m.Called(ctx)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -19,62 +19,66 @@
package router
import (
"context"
"crypto/rand"
"fmt"
"net/http"
"os"
"path/filepath"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
// Router provides the REST interface for gotosocial, using gin.
type Router interface {
// Attach a gin handler to the router with the given method and path
AttachHandler(method string, path string, handler gin.HandlerFunc)
AttachHandler(method string, path string, f gin.HandlerFunc)
// Attach a gin middleware to the router that will be used globally
AttachMiddleware(handler gin.HandlerFunc)
// Start the router
Start()
// Stop the router
Stop()
Stop(ctx context.Context) error
}
// router fulfils the Router interface using gin and logrus
type router struct {
logger *logrus.Logger
engine *gin.Engine
srv *http.Server
}
// Start starts the router nicely
func (s *router) Start() {
// todo: start gracefully
if err := s.engine.Run(); err != nil {
s.logger.Panicf("server error: %s", err)
func (r *router) Start() {
go func() {
if err := r.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
r.logger.Fatalf("listen: %s", err)
}
}()
}
// Stop shuts down the router nicely
func (s *router) Stop() {
// todo: shut down gracefully
func (r *router) Stop(ctx context.Context) error {
return r.srv.Shutdown(ctx)
}
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path.
func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
func (r *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
if method == "ANY" {
s.engine.Any(path, handler)
r.engine.Any(path, handler)
} else {
s.engine.Handle(method, path, handler)
r.engine.Handle(method, path, handler)
}
}
// AttachMiddleware attaches a gin middleware to the router that will be used globally
func (s *router) AttachMiddleware(middleware gin.HandlerFunc) {
s.engine.Use(middleware)
func (r *router) AttachMiddleware(middleware gin.HandlerFunc) {
r.engine.Use(middleware)
}
// New returns a new Router with the specified configuration, using the given logrus logger.
@ -100,6 +104,10 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) {
return &router{
logger: logger,
engine: engine,
srv: &http.Server{
Addr: ":8080",
Handler: engine,
},
}, nil
}

31
internal/storage/inmem.go Normal file
View file

@ -0,0 +1,31 @@
package storage
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
func NewInMem(c *config.Config, log *logrus.Logger) (Storage, error) {
return &inMemStorage{
stored: make(map[string][]byte),
}, nil
}
type inMemStorage struct {
stored map[string][]byte
}
func (s *inMemStorage) StoreFileAt(path string, data []byte) error {
s.stored[path] = data
return nil
}
func (s *inMemStorage) RetrieveFileFrom(path string) ([]byte, error) {
d, ok := s.stored[path]
if !ok {
return nil, fmt.Errorf("no data found at path %s", path)
}
return d, nil
}

21
internal/storage/local.go Normal file
View file

@ -0,0 +1,21 @@
package storage
import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
func NewLocal(c *config.Config, log *logrus.Logger) (Storage, error) {
return &localStorage{}, nil
}
type localStorage struct {
}
func (s *localStorage) StoreFileAt(path string, data []byte) error {
return nil
}
func (s *localStorage) RetrieveFileFrom(path string) ([]byte, error) {
return nil, nil
}

View file

@ -0,0 +1,47 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package storage
import mock "github.com/stretchr/testify/mock"
// MockStorage is an autogenerated mock type for the Storage type
type MockStorage struct {
mock.Mock
}
// RetrieveFileFrom provides a mock function with given fields: path
func (_m *MockStorage) RetrieveFileFrom(path string) ([]byte, error) {
ret := _m.Called(path)
var r0 []byte
if rf, ok := ret.Get(0).(func(string) []byte); ok {
r0 = rf(path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(path)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// StoreFileAt provides a mock function with given fields: path, data
func (_m *MockStorage) StoreFileAt(path string, data []byte) error {
ret := _m.Called(path, data)
var r0 error
if rf, ok := ret.Get(0).(func(string, []byte) error); ok {
r0 = rf(path, data)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -0,0 +1,24 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package storage
type Storage interface {
StoreFileAt(path string, data []byte) error
RetrieveFileFrom(path string) ([]byte, error)
}

32
internal/util/parse.go Normal file
View file

@ -0,0 +1,32 @@
package util
import "fmt"
type URIs struct {
HostURL string
UserURL string
UserURI string
InboxURL string
OutboxURL string
FollowersURL string
CollectionURL string
}
func GenerateURIs(username string, protocol string, host string) *URIs {
hostURL := fmt.Sprintf("%s://%s", protocol, host)
userURL := fmt.Sprintf("%s/@%s", hostURL, username)
userURI := fmt.Sprintf("%s/users/%s", hostURL, username)
inboxURL := fmt.Sprintf("%s/inbox", userURI)
outboxURL := fmt.Sprintf("%s/outbox", userURI)
followersURL := fmt.Sprintf("%s/followers", userURI)
collectionURL := fmt.Sprintf("%s/collections/featured", userURI)
return &URIs{
HostURL: hostURL,
UserURL: userURL,
UserURI: userURI,
InboxURL: inboxURL,
OutboxURL: outboxURL,
FollowersURL: followersURL,
CollectionURL: collectionURL,
}
}

144
internal/util/validation.go Normal file
View file

@ -0,0 +1,144 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package util
import (
"errors"
"fmt"
"net/mail"
"regexp"
pwv "github.com/wagslane/go-password-validator"
"golang.org/x/text/language"
)
const (
// MinimumPasswordEntropy dictates password strength. See https://github.com/wagslane/go-password-validator
MinimumPasswordEntropy = 60
// MinimumReasonLength is the length of chars we expect as a bare minimum effort
MinimumReasonLength = 40
// MaximumReasonLength is the maximum amount of chars we're happy to accept
MaximumReasonLength = 500
// MaximumEmailLength is the maximum length of an email address we're happy to accept
MaximumEmailLength = 256
// MaximumUsernameLength is the maximum length of a username we're happy to accept
MaximumUsernameLength = 64
// MaximumPasswordLength is the maximum length of a password we're happy to accept
MaximumPasswordLength = 64
// NewUsernameRegexString is string representation of the regular expression for validating usernames
NewUsernameRegexString = `^[a-z0-9_]+$`
)
var (
// NewUsernameRegex is the compiled regex for validating new usernames
NewUsernameRegex = regexp.MustCompile(NewUsernameRegexString)
)
// ValidateNewPassword returns an error if the given password is not sufficiently strong, or nil if it's ok.
func ValidateNewPassword(password string) error {
if password == "" {
return errors.New("no password provided")
}
if len(password) > MaximumPasswordLength {
return fmt.Errorf("password should be no more than %d chars", MaximumPasswordLength)
}
return pwv.Validate(password, MinimumPasswordEntropy)
}
// ValidateUsername makes sure that a given username is valid (ie., letters, numbers, underscores, check length).
// Returns an error if not.
func ValidateUsername(username string) error {
if username == "" {
return errors.New("no username provided")
}
if len(username) > MaximumUsernameLength {
return fmt.Errorf("username should be no more than %d chars but '%s' was %d", MaximumUsernameLength, username, len(username))
}
if !NewUsernameRegex.MatchString(username) {
return fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", username)
}
return nil
}
// ValidateEmail makes sure that a given email address is a valid address.
// Returns an error if not.
func ValidateEmail(email string) error {
if email == "" {
return errors.New("no email provided")
}
if len(email) > MaximumEmailLength {
return fmt.Errorf("email address should be no more than %d chars but '%s' was %d", MaximumEmailLength, email, len(email))
}
_, err := mail.ParseAddress(email)
return err
}
// ValidateLanguage checks that the given language string is a 2- or 3-letter ISO 639 code.
// Returns an error if the language cannot be parsed. See: https://pkg.go.dev/golang.org/x/text/language
func ValidateLanguage(lang string) error {
if lang == "" {
return errors.New("no language provided")
}
_, err := language.ParseBase(lang)
return err
}
// ValidateSignUpReason checks that a sufficient reason is given for a server signup request
func ValidateSignUpReason(reason string, reasonRequired bool) error {
if !reasonRequired {
// we don't care!
// we're not going to do anything with this text anyway if no reason is required
return nil
}
if reason == "" {
return errors.New("no reason provided")
}
if len(reason) < MinimumReasonLength {
return fmt.Errorf("reason should be at least %d chars but '%s' was %d", MinimumReasonLength, reason, len(reason))
}
if len(reason) > MaximumReasonLength {
return fmt.Errorf("reason should be no more than %d chars but given reason was %d", MaximumReasonLength, len(reason))
}
return nil
}
func ValidateDisplayName(displayName string) error {
// TODO: add some validation logic here -- length, characters, etc
return nil
}
func ValidateNote(note string) error {
// TODO: add some validation logic here -- length, characters, etc
return nil
}
func ValidatePrivacy(privacy string) error {
// TODO: add some validation logic here -- length, characters, etc
return nil
}

View file

@ -0,0 +1,288 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package util
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type ValidationTestSuite struct {
suite.Suite
}
func (suite *ValidationTestSuite) TestCheckPasswordStrength() {
empty := ""
terriblePassword := "password"
weakPassword := "OKPassword"
shortPassword := "Ok12"
specialPassword := "Ok12%"
longPassword := "thisisafuckinglongpasswordbutnospecialchars"
tooLong := "Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Quisque a enim nibh. Vestibulum bibendum leo ac porttitor auctor."
strongPassword := "3dX5@Zc%mV*W2MBNEy$@"
var err error
err = ValidateNewPassword(empty)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("no password provided"), err)
}
err = ValidateNewPassword(terriblePassword)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using uppercase letters, using numbers or using a longer password"), err)
}
err = ValidateNewPassword(weakPassword)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using numbers or using a longer password"), err)
}
err = ValidateNewPassword(shortPassword)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err)
}
err = ValidateNewPassword(specialPassword)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err)
}
err = ValidateNewPassword(longPassword)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateNewPassword(tooLong)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("password should be no more than 64 chars"), err)
}
err = ValidateNewPassword(strongPassword)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
}
func (suite *ValidationTestSuite) TestValidateUsername() {
empty := ""
tooLong := "holycrapthisisthelongestusernameiveeverseeninmylifethatstoomuchman"
withSpaces := "this username has spaces in it"
weirdChars := "thisusername&&&&&&&istooweird!!"
leadingSpace := " see_that_leading_space"
trailingSpace := "thisusername_ends_with_a_space "
newlines := "this_is\n_almost_ok"
goodUsername := "this_is_a_good_username"
var err error
err = ValidateUsername(empty)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("no username provided"), err)
}
err = ValidateUsername(tooLong)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("username should be no more than 64 chars but '%s' was 66", tooLong), err)
}
err = ValidateUsername(withSpaces)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", withSpaces), err)
}
err = ValidateUsername(weirdChars)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", weirdChars), err)
}
err = ValidateUsername(leadingSpace)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", leadingSpace), err)
}
err = ValidateUsername(trailingSpace)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", trailingSpace), err)
}
err = ValidateUsername(newlines)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", newlines), err)
}
err = ValidateUsername(goodUsername)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
}
func (suite *ValidationTestSuite) TestValidateEmail() {
empty := ""
notAnEmailAddress := "this-is-no-email-address!"
almostAnEmailAddress := "@thisisalmostan@email.address"
aWebsite := "https://thisisawebsite.com"
tooLong := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaahhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhggggggggggggggggggggggggggggggggggggggghhhhhhhhhhhhhhhhhggggggggggggggggggggghhhhhhhhhhhhhhhhhhhhhhhhhhhhhh@gmail.com"
emailAddress := "thisis.actually@anemail.address"
var err error
err = ValidateEmail(empty)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("no email provided"), err)
}
err = ValidateEmail(notAnEmailAddress)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("mail: missing '@' or angle-addr"), err)
}
err = ValidateEmail(almostAnEmailAddress)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("mail: no angle-addr"), err)
}
err = ValidateEmail(aWebsite)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("mail: missing '@' or angle-addr"), err)
}
err = ValidateEmail(tooLong)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), fmt.Errorf("email address should be no more than 256 chars but '%s' was 286", tooLong), err)
}
err = ValidateEmail(emailAddress)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
}
func (suite *ValidationTestSuite) TestValidateLanguage() {
empty := ""
notALanguage := "this isn't a language at all!"
english := "en"
capitalEnglish := "EN"
arabic3Letters := "ara"
mixedCapsEnglish := "eN"
englishUS := "en-us"
dutch := "nl"
german := "de"
var err error
err = ValidateLanguage(empty)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("no language provided"), err)
}
err = ValidateLanguage(notALanguage)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("language: tag is not well-formed"), err)
}
err = ValidateLanguage(english)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateLanguage(capitalEnglish)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateLanguage(arabic3Letters)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateLanguage(mixedCapsEnglish)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateLanguage(englishUS)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("language: tag is not well-formed"), err)
}
err = ValidateLanguage(dutch)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateLanguage(german)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
}
func (suite *ValidationTestSuite) TestValidateReason() {
empty := ""
badReason := "because"
goodReason := "to smash the state and destroy capitalism ultimately and completely"
tooLong := "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris auctor mollis viverra. Maecenas maximus mollis sem, nec fermentum velit consectetur non. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Quisque a enim nibh. Vestibulum bibendum leo ac porttitor auctor. Curabitur velit tellus, facilisis vitae lorem a, ullamcorper efficitur leo. Sed a auctor tortor. Sed ut finibus ante, sit amet laoreet sapien. Donec ullamcorper tellus a nibh sodales vulputate. Donec id dolor eu odio mollis bibendum. Pellentesque habitant morbi tristique senectus et netus at."
var err error
// check with no reason required
err = ValidateSignUpReason(empty, false)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateSignUpReason(badReason, false)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateSignUpReason(tooLong, false)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
err = ValidateSignUpReason(goodReason, false)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
// check with reason required
err = ValidateSignUpReason(empty, true)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("no reason provided"), err)
}
err = ValidateSignUpReason(badReason, true)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("reason should be at least 40 chars but 'because' was 7"), err)
}
err = ValidateSignUpReason(tooLong, true)
if assert.Error(suite.T(), err) {
assert.Equal(suite.T(), errors.New("reason should be no more than 500 chars but given reason was 600"), err)
}
err = ValidateSignUpReason(goodReason, true)
if assert.NoError(suite.T(), err) {
assert.Equal(suite.T(), nil, err)
}
}
func TestValidationTestSuite(t *testing.T) {
suite.Run(t, new(ValidationTestSuite))
}

View file

@ -18,6 +18,8 @@
package mastotypes
import "mime/multipart"
// Account represents a mastodon-api Account object, as described here: https://docs.joinmastodon.org/entities/account/
type Account struct {
// The account id
@ -31,7 +33,7 @@ type Account struct {
// Whether the account manually approves follow requests.
Locked bool `json:"locked"`
// Whether the account has opted into discovery features such as the profile directory.
Discoverable bool `json:"discoverable"`
Discoverable bool `json:"discoverable,omitempty"`
// A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot.
Bot bool `json:"bot"`
// When the account was created. (ISO 8601 Datetime)
@ -61,9 +63,69 @@ type Account struct {
// Additional metadata attached to a profile as name-value pairs.
Fields []Field `json:"fields"`
// An extra entity returned when an account is suspended.
Suspended bool `json:"suspended"`
Suspended bool `json:"suspended,omitempty"`
// When a timed mute will expire, if applicable. (ISO 8601 Datetime)
MuteExpiresAt string `json:"mute_expires_at"`
MuteExpiresAt string `json:"mute_expires_at,omitempty"`
// An extra entity to be used with API methods to verify credentials and update credentials.
Source *Source `json:"source"`
}
// AccountCreateRequest represents the form submitted during a POST request to /api/v1/accounts.
// See https://docs.joinmastodon.org/methods/accounts/
type AccountCreateRequest struct {
// Text that will be reviewed by moderators if registrations require manual approval.
Reason string `form:"reason"`
// The desired username for the account
Username string `form:"username" binding:"required"`
// The email address to be used for login
Email string `form:"email" binding:"required"`
// The password to be used for login
Password string `form:"password" binding:"required"`
// Whether the user agrees to the local rules, terms, and policies.
// These should be presented to the user in order to allow them to consent before setting this parameter to TRUE.
Agreement bool `form:"agreement" binding:"required"`
// The language of the confirmation email that will be sent
Locale string `form:"locale" binding:"required"`
}
// UpdateCredentialsRequest represents the form submitted during a PATCH request to /api/v1/accounts/update_credentials.
// See https://docs.joinmastodon.org/methods/accounts/
type UpdateCredentialsRequest struct {
// Whether the account should be shown in the profile directory.
Discoverable *bool `form:"discoverable"`
// Whether the account has a bot flag.
Bot *bool `form:"bot"`
// The display name to use for the profile.
DisplayName *string `form:"display_name"`
// The account bio.
Note *string `form:"note"`
// Avatar image encoded using multipart/form-data
Avatar *multipart.FileHeader `form:"avatar"`
// Header image encoded using multipart/form-data
Header *multipart.FileHeader `form:"header"`
// Whether manual approval of follow requests is required.
Locked *bool `form:"locked"`
// New Source values for this account
Source *UpdateSource `form:"source"`
// Profile metadata name and value
FieldsAttributes *[]UpdateField `form:"fields_attributes"`
}
// UpdateSource is to be used specifically in an UpdateCredentialsRequest.
type UpdateSource struct {
// Default post privacy for authored statuses.
Privacy *string `form:"privacy"`
// Whether to mark authored statuses as sensitive by default.
Sensitive *bool `form:"sensitive"`
// Default language to use for authored statuses. (ISO 6391)
Language *string `form:"language"`
}
// UpdateField is to be used specifically in an UpdateCredentialsRequest.
// By default, max 4 fields and 255 characters per property/value.
type UpdateField struct {
// Name of the field
Name *string `form:"name"`
// Value of the field
Value *string `form:"value"`
}

View file

@ -43,11 +43,11 @@ type Application struct {
// And here: https://docs.joinmastodon.org/client/token/
type ApplicationPOSTRequest struct {
// A name for your application
ClientName string `form:"client_name"`
ClientName string `form:"client_name" binding:"required"`
// Where the user should be redirected after authorization.
// To display the authorization code to the user instead of redirecting
// to a web page, use urn:ietf:wg:oauth:2.0:oob in this parameter.
RedirectURIs string `form:"redirect_uris"`
RedirectURIs string `form:"redirect_uris" binding:"required"`
// Space separated list of scopes. If none is provided, defaults to read.
Scopes string `form:"scopes"`
// A URL to the homepage of your app

View file

@ -28,7 +28,6 @@ type Field struct {
Value string `json:"value"`
// OPTIONAL
// Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL
VerifiedAt string `json:"verified_at,omitempty"`
}

View file

@ -18,5 +18,24 @@
package mastotypes
// Source represents display or publishing preferences of user's own account.
// Returned as an additional entity when verifying and updated credentials, as an attribute of Account.
// See https://docs.joinmastodon.org/entities/source/
type Source struct {
// The default post privacy to be used for new statuses.
// public = Public post
// unlisted = Unlisted post
// private = Followers-only post
// direct = Direct post
Privacy string `json:"privacy,omitempty"`
// Whether new statuses should be marked sensitive by default.
Sensitive bool `json:"sensitive,omitempty"`
// The default posting language for new statuses.
Language string `json:"language,omitempty"`
// Profile bio.
Note string `json:"note"`
// Metadata about the account.
Fields []Field `json:"fields"`
// The number of pending follow requests.
FollowRequestsCount int `json:"follow_requests_count,omitempty"`
}

View file

@ -18,5 +18,6 @@
package mastotypes
// Tag represents a hashtag used within the content of a status. See https://docs.joinmastodon.org/entities/tag/
type Tag struct {
}

31
pkg/mastotypes/token.go Normal file
View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package mastotypes
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
type Token struct {
// An OAuth token to be used for authorization.
AccessToken string `json:"access_token"`
// The OAuth token type. Mastodon uses Bearer tokens.
TokenType string `json:"token_type"`
// The OAuth scopes granted by this token, space-separated.
Scope string `json:"scope"`
// When the token was generated. (UNIX timestamp seconds)
CreatedAt int64 `json:"created_at"`
}

33
scripts/auth_flow.sh Executable file
View file

@ -0,0 +1,33 @@
#!/bin/sh
set -eux
SERVER_URL="http://localhost:8080"
REDIRECT_URI="${SERVER_URL}"
CLIENT_NAME="Test Application Name"
REGISTRATION_REASON="Testing whether or not this dang diggity thing works!"
REGISTRATION_EMAIL="test@example.org"
REGISTRATION_USERNAME="test_user"
REGISTRATION_PASSWORD="very safe password 123"
REGISTRATION_AGREEMENT="true"
REGISTRATION_LOCALE="en"
# Step 1: create the app to register the new account
CREATE_APP_RESPONSE=$(curl --fail -s -X POST -F "client_name=${CLIENT_NAME}" -F "redirect_uris=${REDIRECT_URI}" "${SERVER_URL}/api/v1/apps")
CLIENT_ID=$(echo "${CREATE_APP_RESPONSE}" | jq -r .client_id)
CLIENT_SECRET=$(echo "${CREATE_APP_RESPONSE}" | jq -r .client_secret)
echo "Obtained client_id: ${CLIENT_ID} and client_secret: ${CLIENT_SECRET}"
# Step 2: obtain a code for that app
APP_CODE_RESPONSE=$(curl --fail -s -X POST -F "scope=read" -F "grant_type=client_credentials" -F "client_id=${CLIENT_ID}" -F "client_secret=${CLIENT_SECRET}" -F "redirect_uri=${REDIRECT_URI}" "${SERVER_URL}/oauth/token")
APP_ACCESS_TOKEN=$(echo "${APP_CODE_RESPONSE}" | jq -r .access_token)
echo "Obtained app access token: ${APP_ACCESS_TOKEN}"
# Step 3: use the code to register a new account
ACCOUNT_REGISTER_RESPONSE=$(curl --fail -s -H "Authorization: Bearer ${APP_ACCESS_TOKEN}" -F "reason=${REGISTRATION_REASON}" -F "email=${REGISTRATION_EMAIL}" -F "username=${REGISTRATION_USERNAME}" -F "password=${REGISTRATION_PASSWORD}" -F "agreement=${REGISTRATION_AGREEMENT}" -F "locale=${REGISTRATION_LOCALE}" "${SERVER_URL}/api/v1/accounts")
USER_ACCESS_TOKEN=$(echo "${ACCOUNT_REGISTER_RESPONSE}" | jq -r .access_token)
echo "Obtained user access token: ${USER_ACCESS_TOKEN}"
# # Step 4: verify the returned access token
curl -s -H "Authorization: Bearer ${USER_ACCESS_TOKEN}" "${SERVER_URL}/api/v1/accounts/verify_credentials" | jq