[chore] The Big Middleware and API Refactor (tm) (#1250)

* interim commit: start refactoring middlewares into package under router

* another interim commit, this is becoming a big job

* another fucking massive interim commit

* refactor bookmarks to new style

* ambassador, wiz zeze commits you are spoiling uz

* she compiles, we're getting there

* we're just normal men; we're just innocent men

* apiutil

* whoopsie

* i'm glad noone reads commit msgs haha :blob_sweat:

* use that weirdo go-bytesize library for maxMultipartMemory

* fix media module paths
This commit is contained in:
tobi 2023-01-02 13:10:50 +01:00 committed by GitHub
parent 560ff1209d
commit 941893a774
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
228 changed files with 3188 additions and 3047 deletions

View file

@ -251,7 +251,7 @@ Most of the crucial business logic of the application is found inside the variou
`internal/ap` - ActivityPub utility functions and interfaces. `internal/ap` - ActivityPub utility functions and interfaces.
`internal/api` - Models, routers, and utilities for the client and federated (s2s) APIs. This is where routes are attached to the router, and where you want to be if you're adding a route. `internal/api` - Models, routers, and utilities for the client and federated (ActivityPub) APIs. This is where routes are attached to the router, and where you want to be if you're adding a route.
`internal/concurrency` - Worker models used by the processor and other queues. `internal/concurrency` - Worker models used by the processor and other queues.
@ -283,6 +283,8 @@ Most of the crucial business logic of the application is found inside the variou
`internal/messages` - Models for wrapping worker messages. `internal/messages` - Models for wrapping worker messages.
`internal/middleware` - Gin Gonic router middlewares: http signature checking, cache control, token checks, etc.
`internal/netutil` - HTTP / net request validation code. `internal/netutil` - HTTP / net request validation code.
`internal/oauth` - Wrapper code/interfaces for OAuth server implementation. `internal/oauth` - Wrapper code/interfaces for OAuth server implementation.

View file

@ -20,37 +20,20 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/api/client/app" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth"
"github.com/superseriousbusiness/gotosocial/internal/api/client/blocks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/emoji"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
"github.com/superseriousbusiness/gotosocial/internal/api/client/filter"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/api/client/list"
mediaModule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/client/notification"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/api/client/timeline"
userClient "github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/nodeinfo"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
@ -106,11 +89,6 @@ var Start action.GTSAction = func(ctx context.Context) error {
federatingDB := federatingdb.New(dbService, fedWorker) federatingDB := federatingdb.New(dbService, fedWorker)
router, err := router.New(ctx, dbService)
if err != nil {
return fmt.Errorf("error creating router: %s", err)
}
// build converters and util // build converters and util
typeConverter := typeutils.NewConverter(dbService) typeConverter := typeutils.NewConverter(dbService)
@ -148,85 +126,72 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
} }
// create and start the message processor using the other services we've created so far // create the message processor using the other services we've created so far
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker) processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker)
if err := processor.Start(); err != nil { if err := processor.Start(); err != nil {
return fmt.Errorf("error starting processor: %s", err) return fmt.Errorf("error creating processor: %s", err)
} }
idp, err := oidc.NewIDP(ctx) /*
HTTP router initialization
*/
router, err := router.New(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error creating oidc idp: %s", err) return fmt.Errorf("error creating router: %s", err)
} }
// build web module // attach global middlewares which are used for every request
webModule := web.New(processor) router.AttachGlobalMiddleware(
middleware.Logger(),
middleware.UserAgent(),
middleware.CORS(),
middleware.ExtraHeaders(),
)
// build client api modules // attach global no route / 404 handler to the router
authModule := auth.New(dbService, idp, processor) router.AttachNoRouteHandler(func(c *gin.Context) {
accountModule := account.New(processor) apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGet)
instanceModule := instance.New(processor) })
appsModule := app.New(processor)
followRequestsModule := followrequest.New(processor)
webfingerModule := webfinger.New(processor)
nodeInfoModule := nodeinfo.New(processor)
usersModule := user.New(processor)
timelineModule := timeline.New(processor)
notificationModule := notification.New(processor)
searchModule := search.New(processor)
filtersModule := filter.New(processor)
emojiModule := emoji.New(processor)
listsModule := list.New(processor)
mm := mediaModule.New(processor)
fileServerModule := fileserver.New(processor)
adminModule := admin.New(processor)
statusModule := status.New(processor)
bookmarksModule := bookmarks.New(processor)
securityModule := security.New(dbService, oauthServer)
streamingModule := streaming.New(processor)
favouritesModule := favourites.New(processor)
blocksModule := blocks.New(processor)
userClientModule := userClient.New(processor)
apis := []api.ClientModule{ // build router modules
// modules with middleware go first var idp oidc.IDP
securityModule, if config.GetOIDCEnabled() {
authModule, idp, err = oidc.NewIDP(ctx)
if err != nil {
// now the web module return fmt.Errorf("error creating oidc idp: %w", err)
webModule,
// now everything else
accountModule,
instanceModule,
appsModule,
followRequestsModule,
mm,
fileServerModule,
adminModule,
statusModule,
bookmarksModule,
webfingerModule,
nodeInfoModule,
usersModule,
timelineModule,
notificationModule,
searchModule,
filtersModule,
emojiModule,
listsModule,
streamingModule,
favouritesModule,
blocksModule,
userClientModule,
}
for _, m := range apis {
if err := m.Route(router); err != nil {
return fmt.Errorf("routing error: %s", err)
} }
} }
routerSession, err := dbService.GetSession(ctx)
if err != nil {
return fmt.Errorf("error retrieving router session for session middleware: %w", err)
}
sessionName, err := middleware.SessionName()
if err != nil {
return fmt.Errorf("error generating session name for session middleware: %w", err)
}
var (
authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(dbService, processor) // api client endpoints
fileserverModule = api.NewFileserver(processor) // fileserver endpoints
wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints
webModule = web.New(processor) // web pages + user profiles + settings panels etc
)
// these should be routed in order
authModule.Route(router)
clientModule.Route(router)
fileserverModule.Route(router)
wellKnownModule.Route(router)
nodeInfoModule.Route(router)
activityPubModule.Route(router)
webModule.Route(router)
gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager)
if err != nil { if err != nil {
return fmt.Errorf("error creating gotosocial service: %s", err) return fmt.Errorf("error creating gotosocial service: %s", err)

View file

@ -21,6 +21,7 @@ package testrig
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -28,35 +29,17 @@ import (
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
"github.com/superseriousbusiness/gotosocial/internal/api/client/app"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth"
"github.com/superseriousbusiness/gotosocial/internal/api/client/blocks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/emoji"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
"github.com/superseriousbusiness/gotosocial/internal/api/client/filter"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/api/client/list"
mediaModule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/client/notification"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/api/client/timeline"
userClient "github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/nodeinfo"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial" "github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/web" "github.com/superseriousbusiness/gotosocial/internal/web"
@ -70,7 +53,6 @@ var Start action.GTSAction = func(ctx context.Context) error {
dbService := testrig.NewTestDB() dbService := testrig.NewTestDB()
testrig.StandardDBSetup(dbService, nil) testrig.StandardDBSetup(dbService, nil)
router := testrig.NewTestRouter(dbService)
var storageBackend *storage.Driver var storageBackend *storage.Driver
if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { if os.Getenv("GTS_STORAGE_BACKEND") == "s3" {
storageBackend, _ = storage.NewS3Storage() storageBackend, _ = storage.NewS3Storage()
@ -84,7 +66,6 @@ var Start action.GTSAction = func(ctx context.Context) error {
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// build backend handlers // build backend handlers
oauthServer := testrig.NewTestOauthServer(dbService)
transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
r := io.NopCloser(bytes.NewReader([]byte{})) r := io.NopCloser(bytes.NewReader([]byte{}))
return &http.Response{ return &http.Response{
@ -102,77 +83,64 @@ var Start action.GTSAction = func(ctx context.Context) error {
return fmt.Errorf("error starting processor: %s", err) return fmt.Errorf("error starting processor: %s", err)
} }
idp, err := oidc.NewIDP(ctx) /*
if err != nil { HTTP router initialization
return fmt.Errorf("error creating oidc idp: %s", err) */
}
// build web module router := testrig.NewTestRouter(dbService)
webModule := web.New(processor)
// build client api modules // attach global middlewares which are used for every request
authModule := auth.New(dbService, idp, processor) router.AttachGlobalMiddleware(
accountModule := account.New(processor) middleware.Logger(),
instanceModule := instance.New(processor) middleware.UserAgent(),
appsModule := app.New(processor) middleware.CORS(),
followRequestsModule := followrequest.New(processor) middleware.ExtraHeaders(),
webfingerModule := webfinger.New(processor) )
nodeInfoModule := nodeinfo.New(processor)
usersModule := user.New(processor)
timelineModule := timeline.New(processor)
notificationModule := notification.New(processor)
searchModule := search.New(processor)
filtersModule := filter.New(processor)
emojiModule := emoji.New(processor)
listsModule := list.New(processor)
mm := mediaModule.New(processor)
fileServerModule := fileserver.New(processor)
adminModule := admin.New(processor)
statusModule := status.New(processor)
securityModule := security.New(dbService, oauthServer)
streamingModule := streaming.New(processor)
favouritesModule := favourites.New(processor)
blocksModule := blocks.New(processor)
userClientModule := userClient.New(processor)
apis := []api.ClientModule{ // attach global no route / 404 handler to the router
// modules with middleware go first router.AttachNoRouteHandler(func(c *gin.Context) {
securityModule, apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGet)
authModule, })
// now the web module // build router modules
webModule, var idp oidc.IDP
var err error
// now everything else if config.GetOIDCEnabled() {
accountModule, idp, err = oidc.NewIDP(ctx)
instanceModule, if err != nil {
appsModule, return fmt.Errorf("error creating oidc idp: %w", err)
followRequestsModule,
mm,
fileServerModule,
adminModule,
statusModule,
webfingerModule,
nodeInfoModule,
usersModule,
timelineModule,
notificationModule,
searchModule,
filtersModule,
emojiModule,
listsModule,
streamingModule,
favouritesModule,
blocksModule,
userClientModule,
}
for _, m := range apis {
if err := m.Route(router); err != nil {
return fmt.Errorf("routing error: %s", err)
} }
} }
routerSession, err := dbService.GetSession(ctx)
if err != nil {
return fmt.Errorf("error retrieving router session for session middleware: %w", err)
}
sessionName, err := middleware.SessionName()
if err != nil {
return fmt.Errorf("error generating session name for session middleware: %w", err)
}
var (
authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(dbService, processor) // api client endpoints
fileserverModule = api.NewFileserver(processor) // fileserver endpoints
wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints
webModule = web.New(processor) // web pages + user profiles + settings panels etc
)
// these should be routed in order
authModule.Route(router)
clientModule.Route(router)
fileserverModule.Route(router)
wellKnownModule.Route(router)
nodeInfoModule.Route(router)
activityPubModule.Route(router)
webModule.Route(router)
gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager)
if err != nil { if err != nil {
return fmt.Errorf("error creating gotosocial service: %s", err) return fmt.Errorf("error creating gotosocial service: %s", err)

View file

@ -1,10 +1,16 @@
# Rate Limit # Rate Limit
To mitigate abuse + scraping of your instance, an IP-based HTTP rate limit is in place. To mitigate abuse + scraping of your instance, IP-based HTTP rate limiting is in place.
This rate limit applies not just to the API, but to all requests (web, federation, etc). There are separate rate limiters configured for different groupings of endpoints. In other words, being rate limited for one part of the API doesn't necessarily mean you will be rate limited for other parts. Each entry in the following list has a separate rate limiter:
By default, a maximum of 1000 requests in a 5 minute time window are allowed. - `/users/*` and `/emoji/*` - ActivityPub (s2s) endpoints.
- `/auth/*` and `/oauth/*` - Sign in + OAUTH token requests.
- `/fileserver/*` - Media attachments, emojis, etc.
- `/nodeinfo/*` - NodeInfo endpoint(s).
- `/.well-known/*` - webfinger + nodeinfo requests.
By default, each rate limiter allows a maximum of 300 requests in a 5 minute time window: 1 request per second per client IP address.
Every response will include the current status of the rate limit with the following headers: Every response will include the current status of the rate limit with the following headers:

View file

@ -1914,7 +1914,7 @@ definitions:
title: SwaggerCollection represents an activitypub collection. title: SwaggerCollection represents an activitypub collection.
type: object type: object
x-go-name: SwaggerCollection x-go-name: SwaggerCollection
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/s2s/user x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users
swaggerCollectionPage: swaggerCollectionPage:
properties: properties:
id: id:
@ -1949,7 +1949,7 @@ definitions:
title: SwaggerCollectionPage represents one page of a collection. title: SwaggerCollectionPage represents one page of a collection.
type: object type: object
x-go-name: SwaggerCollectionPage x-go-name: SwaggerCollectionPage
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/s2s/user x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users
tag: tag:
properties: properties:
name: name:
@ -2049,9 +2049,9 @@ paths:
description: "" description: ""
schema: schema:
$ref: '#/definitions/wellKnownResponse' $ref: '#/definitions/wellKnownResponse'
summary: Directs callers to /nodeinfo/2.0. summary: Returns a well-known response which redirects callers to `/nodeinfo/2.0`.
tags: tags:
- nodeinfo - .well-known
/.well-known/webfinger: /.well-known/webfinger:
get: get:
description: |- description: |-
@ -2074,7 +2074,7 @@ paths:
$ref: '#/definitions/wellKnownResponse' $ref: '#/definitions/wellKnownResponse'
summary: Handles webfinger account lookup requests. summary: Handles webfinger account lookup requests.
tags: tags:
- webfinger - .well-known
/api/{api_version}/media: /api/{api_version}/media:
post: post:
consumes: consumes:

View file

@ -641,9 +641,11 @@ syslog-address: "localhost:514"
# Default: "lax" # Default: "lax"
advanced-cookies-samesite: "lax" advanced-cookies-samesite: "lax"
# Int. Amount of requests to permit from a single IP address within a span of 5 minutes. # Int. Amount of requests to permit per router grouping from a single IP address within
# If this amount is exceeded, a 429 HTTP error code will be returned. # a span of 5 minutes. If this amount is exceeded, a 429 HTTP error code will be returned.
# See https://docs.gotosocial.org/en/latest/api/swagger/#rate-limit. #
# Router groupings and rate limit headers are described here:
# https://docs.gotosocial.org/en/latest/api/swagger/#rate-limit.
# #
# If you find yourself adjusting this limit because it's regularly being exceeded, # If you find yourself adjusting this limit because it's regularly being exceeded,
# you should first verify that your settings for `trusted-proxies` (above) are correct. # you should first verify that your settings for `trusted-proxies` (above) are correct.
@ -655,5 +657,5 @@ advanced-cookies-samesite: "lax"
# If you set this to 0 or less, rate limiting will be disabled entirely. # If you set this to 0 or less, rate limiting will be disabled entirely.
# #
# Examples: [1000, 500, 0] # Examples: [1000, 500, 0]
# Default: 1000 # Default: 300
advanced-rate-limit-requests: 1000 advanced-rate-limit-requests: 300

View file

@ -0,0 +1,66 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package api
import (
"context"
"net/url"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
type ActivityPub struct {
emoji *emoji.Module
users *users.Module
isURIBlocked func(context.Context, *url.URL) (bool, db.Error)
}
func (a *ActivityPub) Route(r router.Router) {
// create groupings for the 'emoji' and 'users' prefixes
emojiGroup := r.AttachGroup("emoji")
usersGroup := r.AttachGroup("users")
// instantiate + attach shared, non-global middlewares to both of these groups
var (
rateLimitMiddleware = middleware.RateLimit() // nolint:contextcheck
signatureCheckMiddleware = middleware.SignatureCheck(a.isURIBlocked)
gzipMiddleware = middleware.Gzip()
cacheControlMiddleware = middleware.CacheControl("no-store")
)
emojiGroup.Use(rateLimitMiddleware, signatureCheckMiddleware, gzipMiddleware, cacheControlMiddleware)
usersGroup.Use(rateLimitMiddleware, signatureCheckMiddleware, gzipMiddleware, cacheControlMiddleware)
a.emoji.Route(emojiGroup.Handle)
a.users.Route(usersGroup.Handle)
}
func NewActivityPub(db db.DB, p processing.Processor) *ActivityPub {
return &ActivityPub{
emoji: emoji.New(p),
users: users.New(p),
isURIBlocked: db.IsURIBlocked,
}
}

View file

@ -21,30 +21,27 @@ package emoji
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// BasePath is the base path for serving the emoji API // EmojiIDKey is for emoji IDs
BasePath = "/api/v1/custom_emojis" EmojiIDKey = "id"
// EmojiBasePath is the base path for serving AP Emojis, minus the "emoji" prefix
EmojiWithIDPath = "/:" + EmojiIDKey
) )
// Module implements the ClientAPIModule interface for everything related to emoji
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new emoji module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, EmojiWithIDPath, m.EmojiGetHandler)
r.AttachHandler(http.MethodGet, BasePath, m.EmojisGETHandler)
return nil
} }

View file

@ -19,54 +19,39 @@
package emoji package emoji
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/ap" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
// EmojiGetHandler
func (m *Module) EmojiGetHandler(c *gin.Context) { func (m *Module) EmojiGetHandler(c *gin.Context) {
// usernames on our instance are always lowercase
requestedEmojiID := strings.ToUpper(c.Param(EmojiIDKey)) requestedEmojiID := strings.ToUpper(c.Param(EmojiIDKey))
if requestedEmojiID == "" { if requestedEmojiID == "" {
err := errors.New("no emoji id specified in request") err := errors.New("no emoji id specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.ActivityPubAcceptHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.ActivityPubAcceptHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
ctx := c.Request.Context() resp, errWithCode := m.processor.GetFediEmoji(apiutil.TransferSignatureContext(c), requestedEmojiID, c.Request.URL)
verifier, signed := c.Get(string(ap.ContextRequestingPublicKeyVerifier))
if signed {
ctx = context.WithValue(ctx, ap.ContextRequestingPublicKeyVerifier, verifier)
}
signature, signed := c.Get(string(ap.ContextRequestingPublicKeySignature))
if signed {
ctx = context.WithValue(ctx, ap.ContextRequestingPublicKeySignature, signature)
}
resp, errWithCode := m.processor.GetFediEmoji(ctx, requestedEmojiID, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -26,8 +26,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/emoji" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
@ -35,7 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -44,20 +43,20 @@ import (
type EmojiGetTestSuite struct { type EmojiGetTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
federator federation.Federator federator federation.Federator
emailSender email.Sender emailSender email.Sender
processor processing.Processor processor processing.Processor
storage *storage.Driver storage *storage.Driver
oauthServer oauth.Server
securityModule *security.Module
testEmojis map[string]*gtsmodel.Emoji testEmojis map[string]*gtsmodel.Emoji
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
emojiModule *emoji.Module emojiModule *emoji.Module
signatureCheck gin.HandlerFunc
} }
func (suite *EmojiGetTestSuite) SetupSuite() { func (suite *EmojiGetTestSuite) SetupSuite() {
@ -79,12 +78,12 @@ func (suite *EmojiGetTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.emojiModule = emoji.New(suite.processor).(*emoji.Module) suite.emojiModule = emoji.New(suite.processor)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.securityModule = security.New(suite.db, suite.oauthServer).(*security.Module)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }
@ -108,7 +107,7 @@ func (suite *EmojiGetTestSuite) TestGetEmoji() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.

View file

@ -16,31 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import (
"context"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/ap"
)
// transferContext transfers the signature verifier and signature from the gin context to the request context
func transferContext(c *gin.Context) context.Context {
ctx := c.Request.Context()
verifier, signed := c.Get(string(ap.ContextRequestingPublicKeyVerifier))
if signed {
ctx = context.WithValue(ctx, ap.ContextRequestingPublicKeyVerifier, verifier)
}
signature, signed := c.Get(string(ap.ContextRequestingPublicKeySignature))
if signed {
ctx = context.WithValue(ctx, ap.ContextRequestingPublicKeySignature, signature)
}
return ctx
}
// SwaggerCollection represents an activitypub collection. // SwaggerCollection represents an activitypub collection.
// swagger:model swaggerCollection // swagger:model swaggerCollection

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -35,31 +35,31 @@ func (m *Module) FollowersGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the user's profile // redirect to the user's profile
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername)
return return
} }
resp, errWithCode := m.processor.GetFediFollowers(transferContext(c), requestedUsername, c.Request.URL) resp, errWithCode := m.processor.GetFediFollowers(apiutil.TransferSignatureContext(c), requestedUsername, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -35,31 +35,31 @@ func (m *Module) FollowingGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the user's profile // redirect to the user's profile
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername)
return return
} }
resp, errWithCode := m.processor.GetFediFollowing(transferContext(c), requestedUsername, c.Request.URL) resp, errWithCode := m.processor.GetFediFollowing(apiutil.TransferSignatureContext(c), requestedUsername, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"errors" "errors"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" //nolint:typecheck "github.com/superseriousbusiness/gotosocial/internal/gtserror" //nolint:typecheck
) )
@ -34,18 +34,18 @@ func (m *Module) InboxPOSTHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if posted, err := m.processor.InboxPost(transferContext(c), c.Writer, c.Request); err != nil { if posted, err := m.processor.InboxPost(apiutil.TransferSignatureContext(c), c.Writer, c.Request); err != nil {
if withCode, ok := err.(gtserror.WithCode); ok { if withCode, ok := err.(gtserror.WithCode); ok {
api.ErrorHandler(c, withCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, withCode, m.processor.InstanceGet)
} else { } else {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
} }
} else if !posted { } else if !posted {
err := errors.New("unable to process request") err := errors.New("unable to process request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
} }
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"bytes" "bytes"
@ -32,7 +32,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -92,7 +92,7 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -105,13 +105,13 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
ctx.Request.Header.Set("Content-Type", "application/activity+json") ctx.Request.Header.Set("Content-Type", "application/activity+json")
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: blockedAccount.Username, Value: blockedAccount.Username,
}, },
} }
@ -196,7 +196,7 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -209,13 +209,13 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
ctx.Request.Header.Set("Content-Type", "application/activity+json") ctx.Request.Header.Set("Content-Type", "application/activity+json")
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: blockedAccount.Username, Value: blockedAccount.Username,
}, },
} }
@ -298,7 +298,7 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -311,13 +311,13 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
ctx.Request.Header.Set("Content-Type", "application/activity+json") ctx.Request.Header.Set("Content-Type", "application/activity+json")
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: receivingAccount.Username, Value: receivingAccount.Username,
}, },
} }
@ -431,7 +431,7 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
// setup request // setup request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@ -443,13 +443,13 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
ctx.Request.Header.Set("Content-Type", "application/activity+json") ctx.Request.Header.Set("Content-Type", "application/activity+json")
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: receivingAccount.Username, Value: receivingAccount.Username,
}, },
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -27,7 +27,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -90,17 +90,17 @@ func (m *Module) OutboxGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the user's profile // redirect to the user's profile
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername)
return return
@ -111,7 +111,7 @@ func (m *Module) OutboxGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(pageString) i, err := strconv.ParseBool(pageString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", PageKey, err) err := fmt.Errorf("error parsing %s: %s", PageKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
page = i page = i
@ -129,15 +129,15 @@ func (m *Module) OutboxGETHandler(c *gin.Context) {
maxID = maxIDString maxID = maxIDString
} }
resp, errWithCode := m.processor.GetFediOutbox(transferContext(c), requestedUsername, page, maxID, minID, c.Request.URL) resp, errWithCode := m.processor.GetFediOutbox(apiutil.TransferSignatureContext(c), requestedUsername, page, maxID, minID, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"context" "context"
@ -30,7 +30,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -55,13 +55,13 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
} }
@ -102,7 +102,7 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -114,13 +114,13 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
} }
@ -161,7 +161,7 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -173,17 +173,17 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
gin.Param{ gin.Param{
Key: user.MaxIDKey, Key: users.MaxIDKey,
Value: "01F8MHAMCHF6Y650WCRSCP4WMY", Value: "01F8MHAMCHF6Y650WCRSCP4WMY",
}, },
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -39,31 +39,31 @@ func (m *Module) PublicKeyGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the user's profile // redirect to the user's profile
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername)
return return
} }
resp, errWithCode := m.processor.GetFediUser(transferContext(c), requestedUsername, c.Request.URL) resp, errWithCode := m.processor.GetFediUser(apiutil.TransferSignatureContext(c), requestedUsername, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -27,7 +27,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -97,7 +97,7 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -105,17 +105,17 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
requestedStatusID := strings.ToUpper(c.Param(StatusIDKey)) requestedStatusID := strings.ToUpper(c.Param(StatusIDKey))
if requestedStatusID == "" { if requestedStatusID == "" {
err := errors.New("no status id specified in request") err := errors.New("no status id specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the status // redirect to the status
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername+"/statuses/"+requestedStatusID) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername+"/statuses/"+requestedStatusID)
return return
@ -126,7 +126,7 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(pageString) i, err := strconv.ParseBool(pageString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", PageKey, err) err := fmt.Errorf("error parsing %s: %s", PageKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
page = i page = i
@ -138,7 +138,7 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(onlyOtherAccountsString) i, err := strconv.ParseBool(onlyOtherAccountsString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", OnlyOtherAccountsKey, err) err := fmt.Errorf("error parsing %s: %s", OnlyOtherAccountsKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
onlyOtherAccounts = i onlyOtherAccounts = i
@ -150,15 +150,15 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
minID = minIDString minID = minIDString
} }
resp, errWithCode := m.processor.GetFediStatusReplies(transferContext(c), requestedUsername, requestedStatusID, page, onlyOtherAccounts, minID, c.Request.URL) resp, errWithCode := m.processor.GetFediStatusReplies(apiutil.TransferSignatureContext(c), requestedUsername, requestedStatusID, page, onlyOtherAccounts, minID, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"context" "context"
@ -32,7 +32,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -58,17 +58,17 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
gin.Param{ gin.Param{
Key: user.StatusIDKey, Key: users.StatusIDKey,
Value: targetStatus.ID, Value: targetStatus.ID,
}, },
} }
@ -111,7 +111,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -123,17 +123,17 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
gin.Param{ gin.Param{
Key: user.StatusIDKey, Key: users.StatusIDKey,
Value: targetStatus.ID, Value: targetStatus.ID,
}, },
} }
@ -179,7 +179,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
userModule := user.New(processor).(*user.Module) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
// setup request // setup request
@ -191,17 +191,17 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
gin.Param{ gin.Param{
Key: user.StatusIDKey, Key: users.StatusIDKey,
Value: targetStatus.ID, Value: targetStatus.ID,
}, },
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -35,7 +35,7 @@ func (m *Module) StatusGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -43,31 +43,31 @@ func (m *Module) StatusGETHandler(c *gin.Context) {
requestedStatusID := strings.ToUpper(c.Param(StatusIDKey)) requestedStatusID := strings.ToUpper(c.Param(StatusIDKey))
if requestedStatusID == "" { if requestedStatusID == "" {
err := errors.New("no status id specified in request") err := errors.New("no status id specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the status // redirect to the status
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername+"/statuses/"+requestedStatusID) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername+"/statuses/"+requestedStatusID)
return return
} }
resp, errWithCode := m.processor.GetFediStatus(transferContext(c), requestedUsername, requestedStatusID, c.Request.URL) resp, errWithCode := m.processor.GetFediStatus(apiutil.TransferSignatureContext(c), requestedUsername, requestedStatusID, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"context" "context"
@ -31,7 +31,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -55,17 +55,17 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
gin.Param{ gin.Param{
Key: user.StatusIDKey, Key: users.StatusIDKey,
Value: targetStatus.ID, Value: targetStatus.ID,
}, },
} }
@ -114,17 +114,17 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: strings.ToLower(targetAccount.Username), Value: strings.ToLower(targetAccount.Username),
}, },
gin.Param{ gin.Param{
Key: user.StatusIDKey, Key: users.StatusIDKey,
Value: strings.ToLower(targetStatus.ID), Value: strings.ToLower(targetStatus.ID),
}, },
} }

View file

@ -0,0 +1,80 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package users
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
const (
// UsernameKey is for account usernames.
UsernameKey = "username"
// StatusIDKey is for status IDs
StatusIDKey = "status"
// OnlyOtherAccountsKey is for filtering status responses.
OnlyOtherAccountsKey = "only_other_accounts"
// MinIDKey is for filtering status responses.
MinIDKey = "min_id"
// MaxIDKey is for filtering status responses.
MaxIDKey = "max_id"
// PageKey is for filtering status responses.
PageKey = "page"
// BasePath is the base path for serving AP 'users' requests, minus the 'users' prefix.
BasePath = "/:" + UsernameKey
// PublicKeyPath is a path to a user's public key, for serving bare minimum AP representations.
PublicKeyPath = BasePath + "/" + uris.PublicKeyPath
// InboxPath is for serving POST requests to a user's inbox with the given username key.
InboxPath = BasePath + "/" + uris.InboxPath
// OutboxPath is for serving GET requests to a user's outbox with the given username key.
OutboxPath = BasePath + "/" + uris.OutboxPath
// FollowersPath is for serving GET request's to a user's followers list, with the given username key.
FollowersPath = BasePath + "/" + uris.FollowersPath
// FollowingPath is for serving GET request's to a user's following list, with the given username key.
FollowingPath = BasePath + "/" + uris.FollowingPath
// StatusPath is for serving GET requests to a particular status by a user, with the given username key and status ID
StatusPath = BasePath + "/" + uris.StatusesPath + "/:" + StatusIDKey
// StatusRepliesPath is for serving the replies collection of a status.
StatusRepliesPath = StatusPath + "/replies"
)
type Module struct {
processor processing.Processor
}
func New(processor processing.Processor) *Module {
return &Module{
processor: processor,
}
}
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, BasePath, m.UsersGETHandler)
attachHandler(http.MethodPost, InboxPath, m.InboxPOSTHandler)
attachHandler(http.MethodGet, FollowersPath, m.FollowersGETHandler)
attachHandler(http.MethodGet, FollowingPath, m.FollowingGETHandler)
attachHandler(http.MethodGet, StatusPath, m.StatusGETHandler)
attachHandler(http.MethodGet, PublicKeyPath, m.PublicKeyGETHandler)
attachHandler(http.MethodGet, StatusRepliesPath, m.StatusRepliesGETHandler)
attachHandler(http.MethodGet, OutboxPath, m.OutboxGETHandler)
}

View file

@ -16,12 +16,12 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
@ -29,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -39,15 +39,13 @@ import (
type UserStandardTestSuite struct { type UserStandardTestSuite struct {
// standard suite interfaces // standard suite interfaces
suite.Suite suite.Suite
db db.DB db db.DB
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
federator federation.Federator federator federation.Federator
emailSender email.Sender emailSender email.Sender
processor processing.Processor processor processing.Processor
storage *storage.Driver storage *storage.Driver
oauthServer oauth.Server
securityModule *security.Module
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -60,7 +58,9 @@ type UserStandardTestSuite struct {
testBlocks map[string]*gtsmodel.Block testBlocks map[string]*gtsmodel.Block
// module being tested // module being tested
userModule *user.Module userModule *users.Module
signatureCheck gin.HandlerFunc
} }
func (suite *UserStandardTestSuite) SetupSuite() { func (suite *UserStandardTestSuite) SetupSuite() {
@ -88,12 +88,12 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.userModule = user.New(suite.processor).(*user.Module) suite.userModule = users.New(suite.processor)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.securityModule = security.New(suite.db, suite.oauthServer).(*security.Module)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user package users
import ( import (
"encoding/json" "encoding/json"
@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
@ -43,31 +43,31 @@ func (m *Module) UsersGETHandler(c *gin.Context) {
requestedUsername := strings.ToLower(c.Param(UsernameKey)) requestedUsername := strings.ToLower(c.Param(UsernameKey))
if requestedUsername == "" { if requestedUsername == "" {
err := errors.New("no username specified in request") err := errors.New("no username specified in request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
format, err := api.NegotiateAccept(c, api.HTMLOrActivityPubHeaders...) format, err := apiutil.NegotiateAccept(c, apiutil.HTMLOrActivityPubHeaders...)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if format == string(api.TextHTML) { if format == string(apiutil.TextHTML) {
// redirect to the user's profile // redirect to the user's profile
c.Redirect(http.StatusSeeOther, "/@"+requestedUsername) c.Redirect(http.StatusSeeOther, "/@"+requestedUsername)
return return
} }
resp, errWithCode := m.processor.GetFediUser(transferContext(c), requestedUsername, c.Request.URL) resp, errWithCode := m.processor.GetFediUser(apiutil.TransferSignatureContext(c), requestedUsername, c.Request.URL)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package user_test package users_test
import ( import (
"context" "context"
@ -30,8 +30,8 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -55,13 +55,13 @@ func (suite *UserGetTestSuite) TestGetUser() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
} }
@ -98,7 +98,7 @@ func (suite *UserGetTestSuite) TestGetUser() {
// TestGetUserPublicKeyDeleted checks whether the public key of a deleted account can still be dereferenced. // TestGetUserPublicKeyDeleted checks whether the public key of a deleted account can still be dereferenced.
// This is needed by remote instances for authenticating delete requests and stuff like that. // This is needed by remote instances for authenticating delete requests and stuff like that.
func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() { func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
userModule := user.New(suite.processor).(*user.Module) userModule := users.New(suite.processor)
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_1"]
// first delete the account, as though zork had deleted himself // first delete the account, as though zork had deleted himself
@ -133,13 +133,13 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
// we need to pass the context through signature check first to set appropriate values on it // we need to pass the context through signature check first to set appropriate values on it
suite.securityModule.SignatureCheck(ctx) suite.signatureCheck(ctx)
// normally the router would populate these params from the path values, // normally the router would populate these params from the path values,
// but because we're calling the function directly, we need to set them manually. // but because we're calling the function directly, we need to set them manually.
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: user.UsernameKey, Key: users.UsernameKey,
Value: targetAccount.Username, Value: targetAccount.Username,
}, },
} }

View file

@ -1,37 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package api
import (
"github.com/superseriousbusiness/gotosocial/internal/router"
)
// ClientModule represents a chunk of code (usually contained in a single package) that adds a set
// of functionalities and/or side effects to a router, by mapping routes and/or middlewares onto it--in other words, a REST API ;)
// A ClientAPIMpdule with routes corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/
type ClientModule interface {
Route(s router.Router) error
}
// FederationModule represents a chunk of code (usually contained in a single package) that adds a set
// of functionalities and/or side effects to a router, by mapping routes and/or middlewares onto it--in other words, a REST API ;)
// Unlike ClientAPIModule, federation API module is not intended to be interacted with by clients directly -- it is primarily a server-to-server interface.
type FederationModule interface {
Route(s router.Router) error
}

64
internal/api/auth.go Normal file
View file

@ -0,0 +1,64 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package api
import (
"github.com/superseriousbusiness/gotosocial/internal/api/auth"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
type Auth struct {
routerSession *gtsmodel.RouterSession
sessionName string
auth *auth.Module
}
// Route attaches 'auth' and 'oauth' groups to the given router.
func (a *Auth) Route(r router.Router) {
// create groupings for the 'auth' and 'oauth' prefixes
authGroup := r.AttachGroup("auth")
oauthGroup := r.AttachGroup("oauth")
// instantiate + attach shared, non-global middlewares to both of these groups
var (
rateLimitMiddleware = middleware.RateLimit() // nolint:contextcheck
gzipMiddleware = middleware.Gzip()
cacheControlMiddleware = middleware.CacheControl("private", "max-age=120")
sessionMiddleware = middleware.Session(a.sessionName, a.routerSession.Auth, a.routerSession.Crypt)
)
authGroup.Use(rateLimitMiddleware, gzipMiddleware, cacheControlMiddleware, sessionMiddleware)
oauthGroup.Use(rateLimitMiddleware, gzipMiddleware, cacheControlMiddleware, sessionMiddleware)
a.auth.RouteAuth(authGroup.Handle)
a.auth.RouteOauth(oauthGroup.Handle)
}
func NewAuth(db db.DB, p processing.Processor, idp oidc.IDP, routerSession *gtsmodel.RouterSession, sessionName string) *Auth {
return &Auth{
routerSession: routerSession,
sessionName: sessionName,
auth: auth.New(db, p, idp),
}
}

117
internal/api/auth/auth.go Normal file
View file

@ -0,0 +1,117 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"net/http"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing"
)
const (
/*
paths prefixed with 'auth'
*/
// AuthSignInPath is the API path for users to sign in through
AuthSignInPath = "/sign_in"
// AuthCheckYourEmailPath users land here after registering a new account, instructs them to confirm their email
AuthCheckYourEmailPath = "/check_your_email"
// AuthWaitForApprovalPath users land here after confirming their email
// but before an admin approves their account (if such is required)
AuthWaitForApprovalPath = "/wait_for_approval"
// AuthAccountDisabledPath users land here when their account is suspended by an admin
AuthAccountDisabledPath = "/account_disabled"
// AuthCallbackPath is the API path for receiving callback tokens from external OIDC providers
AuthCallbackPath = "/callback"
/*
paths prefixed with 'oauth'
*/
// OauthTokenPath is the API path to use for granting token requests to users with valid credentials
OauthTokenPath = "/token" // #nosec G101 else we get a hardcoded credentials warning
// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)
OauthAuthorizePath = "/authorize"
// OauthFinalizePath is the API path for completing user registration with additional user details
OauthFinalizePath = "/finalize"
// OauthOobTokenPath is the path for serving an html representation of an oob token page.
OauthOobTokenPath = "/oob" // #nosec G101 else we get a hardcoded credentials warning
/*
params / session keys
*/
callbackStateParam = "state"
callbackCodeParam = "code"
sessionUserID = "userid"
sessionClientID = "client_id"
sessionRedirectURI = "redirect_uri"
sessionForceLogin = "force_login"
sessionResponseType = "response_type"
sessionScope = "scope"
sessionInternalState = "internal_state"
sessionClientState = "client_state"
sessionClaims = "claims"
sessionAppID = "app_id"
)
type Module struct {
db db.DB
processor processing.Processor
idp oidc.IDP
}
// New returns an Auth module which provides both 'oauth' and 'auth' endpoints.
//
// It is safe to pass a nil idp if oidc is disabled.
func New(db db.DB, processor processing.Processor, idp oidc.IDP) *Module {
return &Module{
db: db,
processor: processor,
idp: idp,
}
}
// RouteAuth routes all paths that should have an 'auth' prefix
func (m *Module) RouteAuth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, AuthSignInPath, m.SignInGETHandler)
attachHandler(http.MethodPost, AuthSignInPath, m.SignInPOSTHandler)
attachHandler(http.MethodGet, AuthCallbackPath, m.CallbackGETHandler)
}
// RouteOauth routes all paths that should have an 'oauth' prefix
func (m *Module) RouteOauth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodPost, OauthTokenPath, m.TokenPOSTHandler)
attachHandler(http.MethodGet, OauthAuthorizePath, m.AuthorizeGETHandler)
attachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler)
attachHandler(http.MethodPost, OauthFinalizePath, m.FinalizePOSTHandler)
attachHandler(http.MethodGet, OauthOobTokenPath, m.OobHandler)
}
func (m *Module) clearSession(s sessions.Session) {
s.Clear()
if err := s.Save(); err != nil {
panic(err)
}
}

View file

@ -20,7 +20,6 @@ package auth_test
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
@ -28,7 +27,7 @@ import (
"github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" "github.com/superseriousbusiness/gotosocial/internal/api/auth"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -37,10 +36,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -54,7 +52,6 @@ type AuthStandardTestSuite struct {
processor processing.Processor processor processing.Processor
emailSender email.Sender emailSender email.Sender
idp oidc.IDP idp oidc.IDP
oauthServer oauth.Server
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -90,17 +87,10 @@ func (suite *AuthStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.authModule = auth.New(suite.db, suite.processor, suite.idp)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
var err error
suite.idp, err = oidc.NewIDP(context.Background())
if err != nil {
panic(err)
}
suite.authModule = auth.New(suite.db, suite.idp, suite.processor).(*auth.Module)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
} }
@ -114,7 +104,7 @@ func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath
ctx, engine := testrig.CreateGinTestContext(recorder, nil) ctx, engine := testrig.CreateGinTestContext(recorder, nil)
// load templates into the engine // load templates into the engine
testrig.ConfigureTemplatesWithGin(engine, "../../../../web/template") testrig.ConfigureTemplatesWithGin(engine, "../../../web/template")
// create the request // create the request
protocol := config.GetProtocol() protocol := config.GetProtocol()
@ -131,7 +121,7 @@ func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath
// trigger the session middleware on the context // trigger the session middleware on the context
store := memstore.NewStore(make([]byte, 32), make([]byte, 32)) store := memstore.NewStore(make([]byte, 32), make([]byte, 32))
store.Options(router.SessionOptions()) store.Options(middleware.SessionOptions())
sessionMiddleware := sessions.Sessions("gotosocial-localhost", store) sessionMiddleware := sessions.Sessions("gotosocial-localhost", store)
sessionMiddleware(ctx) sessionMiddleware(ctx)

View file

@ -27,8 +27,8 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -42,8 +42,8 @@ import (
func (m *Module) AuthorizeGETHandler(c *gin.Context) { func (m *Module) AuthorizeGETHandler(c *gin.Context) {
s := sessions.Default(c) s := sessions.Default(c)
if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -51,20 +51,20 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
userID, ok := s.Get(sessionUserID).(string) userID, ok := s.Get(sessionUserID).(string)
if !ok || userID == "" { if !ok || userID == "" {
form := &model.OAuthAuthorize{} form := &apimodel.OAuthAuthorize{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
if errWithCode := saveAuthFormToSession(s, form); errWithCode != nil { if errWithCode := saveAuthFormToSession(s, form); errWithCode != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.Redirect(http.StatusSeeOther, AuthSignInPath) c.Redirect(http.StatusSeeOther, "/auth"+AuthSignInPath)
return return
} }
@ -73,7 +73,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || clientID == "" { if !ok || clientID == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionClientID) err := fmt.Errorf("key %s was not found in session", sessionClientID)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -87,7 +87,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -101,7 +101,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -115,7 +115,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -128,7 +128,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || redirect == "" { if !ok || redirect == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionRedirectURI) err := fmt.Errorf("key %s was not found in session", sessionRedirectURI)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -136,13 +136,13 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if !ok || scope == "" { if !ok || scope == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionScope) err := fmt.Errorf("key %s was not found in session", sessionScope)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -206,7 +206,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
if len(errs) != 0 { if len(errs) != 0 {
errs = append(errs, oauth.HelpfulAdvice) errs = append(errs, oauth.HelpfulAdvice)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet)
return return
} }
@ -220,7 +220,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -234,7 +234,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -263,13 +263,13 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
} }
if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil { if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
} }
} }
// saveAuthFormToSession checks the given OAuthAuthorize form, // saveAuthFormToSession checks the given OAuthAuthorize form,
// and stores the values in the form into the session. // and stores the values in the form into the session.
func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtserror.WithCode { func saveAuthFormToSession(s sessions.Session, form *apimodel.OAuthAuthorize) gtserror.WithCode {
if form == nil { if form == nil {
err := errors.New("OAuthAuthorize form was nil") err := errors.New("OAuthAuthorize form was nil")
return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice)
@ -314,19 +314,19 @@ func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtser
func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) (redirected bool) { func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) (redirected bool) {
if user.ConfirmedAt.IsZero() { if user.ConfirmedAt.IsZero() {
ctx.Redirect(http.StatusSeeOther, CheckYourEmailPath) ctx.Redirect(http.StatusSeeOther, "/auth"+AuthCheckYourEmailPath)
redirected = true redirected = true
return return
} }
if !*user.Approved { if !*user.Approved {
ctx.Redirect(http.StatusSeeOther, WaitForApprovalPath) ctx.Redirect(http.StatusSeeOther, "/auth"+AuthWaitForApprovalPath)
redirected = true redirected = true
return return
} }
if *user.Disabled || !account.SuspendedAt.IsZero() { if *user.Disabled || !account.SuspendedAt.IsZero() {
ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) ctx.Redirect(http.StatusSeeOther, "/auth"+AuthAccountDisabledPath)
redirected = true redirected = true
return return
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" "github.com/superseriousbusiness/gotosocial/internal/api/auth"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -34,7 +34,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
return []string{"confirmed_at"} return []string{"confirmed_at"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.CheckYourEmailPath, expectedLocationHeader: "/auth" + auth.AuthCheckYourEmailPath,
}, },
{ {
description: "user has their email confirmed but is not approved", description: "user has their email confirmed but is not approved",
@ -44,7 +44,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
return []string{"confirmed_at", "email"} return []string{"confirmed_at", "email"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.WaitForApprovalPath, expectedLocationHeader: "/auth" + auth.AuthWaitForApprovalPath,
}, },
{ {
description: "user has their email confirmed and is approved, but User entity has been disabled", description: "user has their email confirmed and is approved, but User entity has been disabled",
@ -56,7 +56,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
return []string{"confirmed_at", "email", "approved", "disabled"} return []string{"confirmed_at", "email", "approved", "disabled"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath, expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath,
}, },
{ {
description: "user has their email confirmed and is approved, but Account entity has been suspended", description: "user has their email confirmed and is approved, but Account entity has been suspended",
@ -69,7 +69,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
return []string{"confirmed_at", "email", "approved", "disabled"} return []string{"confirmed_at", "email", "approved", "disabled"}
}, },
expectedStatusCode: http.StatusSeeOther, expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath, expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath,
}, },
} }

View file

@ -29,7 +29,7 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -47,6 +47,12 @@ type extraInfo struct {
// CallbackGETHandler parses a token from an external auth provider. // CallbackGETHandler parses a token from an external auth provider.
func (m *Module) CallbackGETHandler(c *gin.Context) { func (m *Module) CallbackGETHandler(c *gin.Context) {
if !config.GetOIDCEnabled() {
err := errors.New("oidc is not enabled for this server")
apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet)
return
}
s := sessions.Default(c) s := sessions.Default(c)
// check the query vs session state parameter to mitigate csrf // check the query vs session state parameter to mitigate csrf
@ -56,7 +62,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
if returnedInternalState == "" { if returnedInternalState == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam) err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -65,14 +71,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
if !ok { if !ok {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionInternalState) err := fmt.Errorf("key %s was not found in session", sessionInternalState)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if returnedInternalState != savedInternalState { if returnedInternalState != savedInternalState {
m.clearSession(s) m.clearSession(s)
err := errors.New("mismatch between callback state and saved state") err := errors.New("mismatch between callback state and saved state")
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -81,14 +87,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
if code == "" { if code == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam) err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code) claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code)
if errWithCode != nil { if errWithCode != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -98,7 +104,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
if !ok || clientID == "" { if !ok || clientID == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionClientID) err := fmt.Errorf("key %s was not found in session", sessionClientID)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -112,21 +118,21 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
if errWithCode != nil { if errWithCode != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
if user == nil { if user == nil {
// no user exists yet - let's ask them for their preferred username // no user exists yet - let's ask them for their preferred username
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -135,7 +141,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
s.Set(sessionAppID, app.ID) s.Set(sessionAppID, app.ID)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }
c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ c.HTML(http.StatusOK, "finalize.tmpl", gin.H{
@ -148,10 +154,10 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
s.Set(sessionUserID, user.ID) s.Set(sessionUserID, user.ID)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }
c.Redirect(http.StatusFound, OauthAuthorizePath) c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
} }
// FinalizePOSTHandler registers the user after additional data has been provided // FinalizePOSTHandler registers the user after additional data has been provided
@ -161,7 +167,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
form := &extraInfo{} form := &extraInfo{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -169,7 +175,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
validationError := func(err error) { validationError := func(err error) {
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ c.HTML(http.StatusOK, "finalize.tmpl", gin.H{
@ -189,7 +195,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
// see if the username is still available // see if the username is still available
usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username) usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
if !usernameAvailable { if !usernameAvailable {
@ -201,7 +207,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
appID, ok := s.Get(sessionAppID).(string) appID, ok := s.Get(sessionAppID).(string)
if !ok { if !ok {
err := fmt.Errorf("key %s was not found in session", sessionAppID) err := fmt.Errorf("key %s was not found in session", sessionAppID)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -209,7 +215,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
claims, ok := s.Get(sessionClaims).(*oidc.Claims) claims, ok := s.Get(sessionClaims).(*oidc.Claims)
if !ok { if !ok {
err := fmt.Errorf("key %s was not found in session", sessionClaims) err := fmt.Errorf("key %s was not found in session", sessionClaims)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -217,7 +223,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID) user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID)
if errWithCode != nil { if errWithCode != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
s.Delete(sessionClaims) s.Delete(sessionClaims)
@ -225,10 +231,10 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) {
s.Set(sessionUserID, user.ID) s.Set(sessionUserID, user.ID)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet)
return return
} }
c.Redirect(http.StatusFound, OauthAuthorizePath) c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
} }
func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) {

View file

@ -26,8 +26,8 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -38,16 +38,18 @@ func (m *Module) OobHandler(c *gin.Context) {
host := config.GetHost() host := config.GetHost()
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), host) instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), host)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
instanceGet := func(ctx context.Context, domain string) (*model.Instance, gtserror.WithCode) { return instance, nil } instanceGet := func(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) {
return instance, nil
}
oobToken := c.Query("code") oobToken := c.Query("code")
if oobToken == "" { if oobToken == "" {
err := errors.New("no 'code' query value provided in callback redirect") err := errors.New("no 'code' query value provided in callback redirect")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet)
return return
} }
@ -67,7 +69,7 @@ func (m *Module) OobHandler(c *gin.Context) {
if len(errs) != 0 { if len(errs) != 0 {
errs = append(errs, oauth.HelpfulAdvice) errs = append(errs, oauth.HelpfulAdvice)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGet)
return return
} }
@ -81,7 +83,7 @@ func (m *Module) OobHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, instanceGet) apiutil.ErrorHandler(c, errWithCode, instanceGet)
return return
} }
@ -95,7 +97,7 @@ func (m *Module) OobHandler(c *gin.Context) {
} else { } else {
errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice)
} }
api.ErrorHandler(c, errWithCode, instanceGet) apiutil.ErrorHandler(c, errWithCode, instanceGet)
return return
} }

View file

@ -26,7 +26,7 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -44,15 +44,15 @@ type login struct {
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler. // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler.
// If an idp provider is set, then the user will be redirected to that to do their sign in. // If an idp provider is set, then the user will be redirected to that to do their sign in.
func (m *Module) SignInGETHandler(c *gin.Context) { func (m *Module) SignInGETHandler(c *gin.Context) {
if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
if m.idp == nil { if !config.GetOIDCEnabled() {
instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost())
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -71,7 +71,7 @@ func (m *Module) SignInGETHandler(c *gin.Context) {
if !ok { if !ok {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionInternalState) err := fmt.Errorf("key %s was not found in session", sessionInternalState)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -87,7 +87,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
form := &login{} form := &login{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
m.clearSession(s) m.clearSession(s)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
return return
} }
@ -95,17 +95,17 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
if errWithCode != nil { if errWithCode != nil {
// don't clear session here, so the user can just press back and try again // don't clear session here, so the user can just press back and try again
// if they accidentally gave the wrong password or something // if they accidentally gave the wrong password or something
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
s.Set(sessionUserID, userid) s.Set(sessionUserID, userid)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
err := fmt.Errorf("error saving user id onto session: %s", err) err := fmt.Errorf("error saving user id onto session: %s", err)
api.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGet)
} }
c.Redirect(http.StatusFound, OauthAuthorizePath) c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath)
} }
// ValidatePassword takes an email address and a password. // ValidatePassword takes an email address and a password.

View file

@ -22,7 +22,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -41,8 +41,8 @@ type tokenRequestForm struct {
// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token // TokenPOSTHandler should be served as a POST at https://example.org/oauth/token
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
func (m *Module) TokenPOSTHandler(c *gin.Context) { func (m *Module) TokenPOSTHandler(c *gin.Context) {
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -50,7 +50,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {
form := &tokenRequestForm{} form := &tokenRequestForm{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error()))
return return
} }
@ -99,13 +99,13 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) {
} }
if len(help) != 0 { if len(help) != 0 {
api.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...))
return return
} }
token, errWithCode := m.processor.OAuthHandleTokenRequest(c.Request) token, errWithCode := m.processor.OAuthHandleTokenRequest(c.Request)
if errWithCode != nil { if errWithCode != nil {
api.OAuthErrorHandler(c, errWithCode) apiutil.OAuthErrorHandler(c, errWithCode)
return return
} }

129
internal/api/client.go Normal file
View file

@ -0,0 +1,129 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package api
import (
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
"github.com/superseriousbusiness/gotosocial/internal/api/client/apps"
"github.com/superseriousbusiness/gotosocial/internal/api/client/blocks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/customemojis"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
filter "github.com/superseriousbusiness/gotosocial/internal/api/client/filters"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/api/client/lists"
"github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/client/notifications"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search"
"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/api/client/timelines"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
type Client struct {
processor processing.Processor
db db.DB
accounts *accounts.Module // api/v1/accounts
admin *admin.Module // api/v1/admin
apps *apps.Module // api/v1/apps
blocks *blocks.Module // api/v1/blocks
bookmarks *bookmarks.Module // api/v1/bookmarks
customEmojis *customemojis.Module // api/v1/custom_emojis
favourites *favourites.Module // api/v1/favourites
filters *filter.Module // api/v1/filters
followRequests *followrequests.Module // api/v1/follow_requests
instance *instance.Module // api/v1/instance
lists *lists.Module // api/v1/lists
media *media.Module // api/v1/media, api/v2/media
notifications *notifications.Module // api/v1/notifications
search *search.Module // api/v1/search, api/v2/search
statuses *statuses.Module // api/v1/statuses
streaming *streaming.Module // api/v1/streaming
timelines *timelines.Module // api/v1/timelines
user *user.Module // api/v1/user
}
func (c *Client) Route(r router.Router) {
// create a new group on the top level client 'api' prefix
apiGroup := r.AttachGroup("api")
// attach non-global middlewares appropriate to the client api
apiGroup.Use(
middleware.TokenCheck(c.db, c.processor.OAuthValidateBearerToken),
middleware.RateLimit(),
middleware.Gzip(),
middleware.CacheControl("no-store"), // never cache api responses
)
// for each client api module, pass it the Handle function
// so that the module can attach its routes to this group
h := apiGroup.Handle
c.accounts.Route(h)
c.admin.Route(h)
c.apps.Route(h)
c.blocks.Route(h)
c.bookmarks.Route(h)
c.customEmojis.Route(h)
c.favourites.Route(h)
c.filters.Route(h)
c.followRequests.Route(h)
c.instance.Route(h)
c.lists.Route(h)
c.media.Route(h)
c.notifications.Route(h)
c.search.Route(h)
c.statuses.Route(h)
c.streaming.Route(h)
c.timelines.Route(h)
c.user.Route(h)
}
func NewClient(db db.DB, p processing.Processor) *Client {
return &Client{
processor: p,
db: db,
accounts: accounts.New(p),
admin: admin.New(p),
apps: apps.New(p),
blocks: blocks.New(p),
bookmarks: bookmarks.New(p),
customEmojis: customemojis.New(p),
favourites: favourites.New(p),
filters: filter.New(p),
followRequests: followrequests.New(p),
instance: instance.New(p),
lists: lists.New(p),
media: media.New(p),
notifications: notifications.New(p),
search: search.New(p),
statuses: statuses.New(p),
streaming: streaming.New(p),
timelines: timelines.New(p),
user: user.New(p),
}
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"bytes" "bytes"
@ -26,7 +26,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -62,7 +62,7 @@ type AccountStandardTestSuite struct {
testStatuses map[string]*gtsmodel.Status testStatuses map[string]*gtsmodel.Status
// module being tested // module being tested
accountModule *account.Module accountsModule *accounts.Module
} }
func (suite *AccountStandardTestSuite) SetupSuite() { func (suite *AccountStandardTestSuite) SetupSuite() {
@ -89,7 +89,7 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.accountModule = account.New(suite.processor).(*account.Module) suite.accountsModule = accounts.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
@ -24,8 +24,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -73,23 +73,23 @@ import (
func (m *Module) AccountCreatePOSTHandler(c *gin.Context) { func (m *Module) AccountCreatePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, false, false) authed, err := oauth.Authed(c, true, true, false, false)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.AccountCreateRequest{} form := &apimodel.AccountCreateRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if err := validateCreateAccount(form); err != nil { if err := validateCreateAccount(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -97,14 +97,14 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) {
signUpIP := net.ParseIP(clientIP) signUpIP := net.ParseIP(clientIP)
if signUpIP == nil { if signUpIP == nil {
err := errors.New("ip address could not be parsed from request") err := errors.New("ip address could not be parsed from request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form.IP = signUpIP form.IP = signUpIP
ti, errWithCode := m.processor.AccountCreate(c.Request.Context(), authed, form) ti, errWithCode := m.processor.AccountCreate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -113,7 +113,7 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) {
// validateCreateAccount checks through all the necessary prerequisites for creating a new account, // validateCreateAccount checks through all the necessary prerequisites for creating a new account,
// according to the provided account create request. If the account isn't eligible, an error will be returned. // according to the provided account create request. If the account isn't eligible, an error will be returned.
func validateCreateAccount(form *model.AccountCreateRequest) error { func validateCreateAccount(form *apimodel.AccountCreateRequest) error {
if form == nil { if form == nil {
return errors.New("form was nil") return errors.New("form was nil")
} }

View file

@ -16,4 +16,4 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
// */ // */
package account_test package accounts_test

View file

@ -16,15 +16,15 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -68,26 +68,26 @@ import (
func (m *Module) AccountDeletePOSTHandler(c *gin.Context) { func (m *Module) AccountDeletePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.AccountDeleteRequest{} form := &apimodel.AccountDeleteRequest{}
if err := c.ShouldBind(&form); err != nil { if err := c.ShouldBind(&form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if form.Password == "" { if form.Password == "" {
err = errors.New("no password provided in account delete request") err = errors.New("no password provided in account delete request")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form.DeleteOriginID = authed.Account.ID form.DeleteOriginID = authed.Account.ID
if errWithCode := m.processor.AccountDeleteLocal(c.Request.Context(), authed, form); errWithCode != nil { if errWithCode := m.processor.AccountDeleteLocal(c.Request.Context(), authed, form); errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"net/http" "net/http"
@ -24,7 +24,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -45,10 +45,10 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandler() {
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, account.DeleteAccountPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, accounts.DeleteAccountPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountDeletePOSTHandler(ctx) suite.accountsModule.AccountDeletePOSTHandler(ctx)
// 1. we should have Accepted because our request was valid // 1. we should have Accepted because our request was valid
suite.Equal(http.StatusAccepted, recorder.Code) suite.Equal(http.StatusAccepted, recorder.Code)
@ -67,10 +67,10 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandlerWrongPassword()
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, account.DeleteAccountPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, accounts.DeleteAccountPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountDeletePOSTHandler(ctx) suite.accountsModule.AccountDeletePOSTHandler(ctx)
// 1. we should have Forbidden because we supplied the wrong password // 1. we should have Forbidden because we supplied the wrong password
suite.Equal(http.StatusForbidden, recorder.Code) suite.Equal(http.StatusForbidden, recorder.Code)
@ -87,10 +87,10 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandlerNoPassword() {
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, account.DeleteAccountPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPost, bodyBytes, accounts.DeleteAccountPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountDeletePOSTHandler(ctx) suite.accountsModule.AccountDeletePOSTHandler(ctx)
// 1. we should have StatusBadRequest because our request was invalid // 1. we should have StatusBadRequest because our request was invalid
suite.Equal(http.StatusBadRequest, recorder.Code) suite.Equal(http.StatusBadRequest, recorder.Code)

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -69,25 +69,25 @@ import (
func (m *Module) AccountGETHandler(c *gin.Context) { func (m *Module) AccountGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
acctInfo, errWithCode := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID) acctInfo, errWithCode := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,17 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
@ -49,8 +45,8 @@ const (
// IDKey is the key to use for retrieving account ID in requests // IDKey is the key to use for retrieving account ID in requests
IDKey = "id" IDKey = "id"
// BasePath is the base API path for this module // BasePath is the base API path for this module, excluding the 'api' prefix
BasePath = "/api/v1/accounts" BasePath = "/v1/accounts"
// BasePathWithID is the base path for this module with the ID key // BasePathWithID is the base path for this module with the ID key
BasePathWithID = BasePath + "/:" + IDKey BasePathWithID = BasePath + "/:" + IDKey
// VerifyPath is for verifying account credentials // VerifyPath is for verifying account credentials
@ -77,65 +73,47 @@ const (
DeleteAccountPath = BasePath + "/delete" DeleteAccountPath = BasePath + "/delete"
) )
// Module implements the ClientAPIModule interface for account-related actions
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new account module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error {
// create account // create account
r.AttachHandler(http.MethodPost, BasePath, m.AccountCreatePOSTHandler) attachHandler(http.MethodPost, BasePath, m.AccountCreatePOSTHandler)
// delete account
r.AttachHandler(http.MethodPost, DeleteAccountPath, m.AccountDeletePOSTHandler)
// get account // get account
r.AttachHandler(http.MethodGet, BasePathWithID, m.muxHandler) attachHandler(http.MethodGet, BasePathWithID, m.AccountGETHandler)
// delete account
attachHandler(http.MethodPost, DeleteAccountPath, m.AccountDeletePOSTHandler)
// verify account
attachHandler(http.MethodGet, VerifyPath, m.AccountVerifyGETHandler)
// modify account // modify account
r.AttachHandler(http.MethodPatch, BasePathWithID, m.muxHandler) attachHandler(http.MethodPatch, UpdateCredentialsPath, m.AccountUpdateCredentialsPATCHHandler)
// get account's statuses // get account's statuses
r.AttachHandler(http.MethodGet, GetStatusesPath, m.AccountStatusesGETHandler) attachHandler(http.MethodGet, GetStatusesPath, m.AccountStatusesGETHandler)
// get following or followers // get following or followers
r.AttachHandler(http.MethodGet, GetFollowersPath, m.AccountFollowersGETHandler) attachHandler(http.MethodGet, GetFollowersPath, m.AccountFollowersGETHandler)
r.AttachHandler(http.MethodGet, GetFollowingPath, m.AccountFollowingGETHandler) attachHandler(http.MethodGet, GetFollowingPath, m.AccountFollowingGETHandler)
// get relationship with account // get relationship with account
r.AttachHandler(http.MethodGet, GetRelationshipsPath, m.AccountRelationshipsGETHandler) attachHandler(http.MethodGet, GetRelationshipsPath, m.AccountRelationshipsGETHandler)
// follow or unfollow account // follow or unfollow account
r.AttachHandler(http.MethodPost, FollowPath, m.AccountFollowPOSTHandler) attachHandler(http.MethodPost, FollowPath, m.AccountFollowPOSTHandler)
r.AttachHandler(http.MethodPost, UnfollowPath, m.AccountUnfollowPOSTHandler) attachHandler(http.MethodPost, UnfollowPath, m.AccountUnfollowPOSTHandler)
// block or unblock account // block or unblock account
r.AttachHandler(http.MethodPost, BlockPath, m.AccountBlockPOSTHandler) attachHandler(http.MethodPost, BlockPath, m.AccountBlockPOSTHandler)
r.AttachHandler(http.MethodPost, UnblockPath, m.AccountUnblockPOSTHandler) attachHandler(http.MethodPost, UnblockPath, m.AccountUnblockPOSTHandler)
return nil
}
func (m *Module) muxHandler(c *gin.Context) {
ru := c.Request.RequestURI
switch c.Request.Method {
case http.MethodGet:
if strings.HasPrefix(ru, VerifyPath) {
m.AccountVerifyGETHandler(c)
} else {
m.AccountGETHandler(c)
}
case http.MethodPatch:
if strings.HasPrefix(ru, UpdateCredentialsPath) {
m.AccountUpdateCredentialsPATCHHandler(c)
}
}
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
@ -25,8 +25,8 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -138,33 +138,33 @@ import (
func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) { func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
form, err := parseUpdateAccountForm(c) form, err := parseUpdateAccountForm(c)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
acctSensitive, errWithCode := m.processor.AccountUpdate(c.Request.Context(), authed, form) acctSensitive, errWithCode := m.processor.AccountUpdate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.JSON(http.StatusOK, acctSensitive) c.JSON(http.StatusOK, acctSensitive)
} }
func parseUpdateAccountForm(c *gin.Context) (*model.UpdateCredentialsRequest, error) { func parseUpdateAccountForm(c *gin.Context) (*apimodel.UpdateCredentialsRequest, error) {
form := &model.UpdateCredentialsRequest{ form := &apimodel.UpdateCredentialsRequest{
Source: &model.UpdateSource{}, Source: &apimodel.UpdateSource{},
} }
if err := c.ShouldBind(&form); err != nil { if err := c.ShouldBind(&form); err != nil {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"context" "context"
@ -27,7 +27,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -50,10 +50,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler()
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -89,10 +89,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUnl
} }
bodyBytes1 := requestBody1.Bytes() bodyBytes1 := requestBody1.Bytes()
recorder1 := httptest.NewRecorder() recorder1 := httptest.NewRecorder()
ctx1 := suite.newContext(recorder1, http.MethodPatch, bodyBytes1, account.UpdateCredentialsPath, w1.FormDataContentType()) ctx1 := suite.newContext(recorder1, http.MethodPatch, bodyBytes1, accounts.UpdateCredentialsPath, w1.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx1) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx1)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder1.Code) suite.Equal(http.StatusOK, recorder1.Code)
@ -125,10 +125,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUnl
} }
bodyBytes2 := requestBody2.Bytes() bodyBytes2 := requestBody2.Bytes()
recorder2 := httptest.NewRecorder() recorder2 := httptest.NewRecorder()
ctx2 := suite.newContext(recorder2, http.MethodPatch, bodyBytes2, account.UpdateCredentialsPath, w2.FormDataContentType()) ctx2 := suite.newContext(recorder2, http.MethodPatch, bodyBytes2, accounts.UpdateCredentialsPath, w2.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx2) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx2)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder1.Code) suite.Equal(http.StatusOK, recorder1.Code)
@ -170,10 +170,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerGet
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -212,10 +212,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerTwo
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -266,10 +266,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerWit
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -308,10 +308,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerEmp
// set up the request // set up the request
bodyBytes := []byte{} bodyBytes := []byte{}
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, "") ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, "")
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusBadRequest, recorder.Code) suite.Equal(http.StatusBadRequest, recorder.Code)
@ -343,10 +343,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUpd
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -385,10 +385,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUpd
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -430,10 +430,10 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUpd
} }
bodyBytes := requestBody.Bytes() bodyBytes := requestBody.Bytes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, accounts.UpdateCredentialsPath, w.FormDataContentType())
// call the handler // call the handler
suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) suite.accountsModule.AccountUpdateCredentialsPATCHHandler(ctx)
suite.Equal(http.StatusBadRequest, recorder.Code) suite.Equal(http.StatusBadRequest, recorder.Code)

View file

@ -16,13 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -59,18 +59,18 @@ import (
func (m *Module) AccountVerifyGETHandler(c *gin.Context) { func (m *Module) AccountVerifyGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
acctSensitive, errWithCode := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID) acctSensitive, errWithCode := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"encoding/json" "encoding/json"
@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
@ -42,10 +42,10 @@ func (suite *AccountVerifyTestSuite) TestAccountVerifyGet() {
// set up the request // set up the request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodGet, nil, account.VerifyPath, "") ctx := suite.newContext(recorder, http.MethodGet, nil, accounts.VerifyPath, "")
// call the handler // call the handler
suite.accountModule.AccountVerifyGETHandler(ctx) suite.accountsModule.AccountVerifyGETHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -69,25 +69,25 @@ import (
func (m *Module) AccountBlockPOSTHandler(c *gin.Context) { func (m *Module) AccountBlockPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID) relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"fmt" "fmt"
@ -29,7 +29,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -46,16 +46,16 @@ func (suite *BlockTestSuite) TestBlockSelf() {
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(account.BlockPath, ":id", testAcct.ID, 1)), nil) ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(accounts.BlockPath, ":id", testAcct.ID, 1)), nil)
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: account.IDKey, Key: accounts.IDKey,
Value: testAcct.ID, Value: testAcct.ID,
}, },
} }
suite.accountModule.AccountBlockPOSTHandler(ctx) suite.accountsModule.AccountBlockPOSTHandler(ctx)
// 1. status should be Not Acceptable due to attempted self-block // 1. status should be Not Acceptable due to attempted self-block
suite.Equal(http.StatusNotAcceptable, recorder.Code) suite.Equal(http.StatusNotAcceptable, recorder.Code)

View file

@ -16,15 +16,15 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -91,32 +91,32 @@ import (
func (m *Module) AccountFollowPOSTHandler(c *gin.Context) { func (m *Module) AccountFollowPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.AccountFollowRequest{} form := &apimodel.AccountFollowRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form.ID = targetAcctID form.ID = targetAcctID
relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form) relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"fmt" "fmt"
@ -29,7 +29,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -46,17 +46,17 @@ func (suite *FollowTestSuite) TestFollowSelf() {
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(account.FollowPath, ":id", testAcct.ID, 1)), nil) ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(accounts.FollowPath, ":id", testAcct.ID, 1)), nil)
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: account.IDKey, Key: accounts.IDKey,
Value: testAcct.ID, Value: testAcct.ID,
}, },
} }
// call the handler // call the handler
suite.accountModule.AccountFollowPOSTHandler(ctx) suite.accountsModule.AccountFollowPOSTHandler(ctx)
// 1. status should be Not Acceptable due to self-follow attempt // 1. status should be Not Acceptable due to self-follow attempt
suite.Equal(http.StatusNotAcceptable, recorder.Code) suite.Equal(http.StatusNotAcceptable, recorder.Code)

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -72,25 +72,25 @@ import (
func (m *Module) AccountFollowersGETHandler(c *gin.Context) { func (m *Module) AccountFollowersGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID) followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -72,25 +72,25 @@ import (
func (m *Module) AccountFollowingGETHandler(c *gin.Context) { func (m *Module) AccountFollowingGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID) following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -1,12 +1,12 @@
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -57,12 +57,12 @@ import (
func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) { func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -72,18 +72,18 @@ func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) {
id := c.Query("id") id := c.Query("id")
if id == "" { if id == "" {
err = errors.New("no account id(s) specified in query") err = errors.New("no account id(s) specified in query")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAccountIDs = append(targetAccountIDs, id) targetAccountIDs = append(targetAccountIDs, id)
} }
relationships := []model.Relationship{} relationships := []apimodel.Relationship{}
for _, targetAccountID := range targetAccountIDs { for _, targetAccountID := range targetAccountIDs {
r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID) r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
relationships = append(relationships, *r) relationships = append(relationships, *r)

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
@ -25,7 +25,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -133,19 +133,19 @@ import (
func (m *Module) AccountStatusesGETHandler(c *gin.Context) { func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, false, false, false, false) authed, err := oauth.Authed(c, false, false, false, false)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -155,7 +155,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseInt(limitString, 10, 32) i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err) err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
limit = int(i) limit = int(i)
@ -167,7 +167,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(excludeRepliesString) i, err := strconv.ParseBool(excludeRepliesString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExcludeRepliesKey, err) err := fmt.Errorf("error parsing %s: %s", ExcludeRepliesKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
excludeReplies = i excludeReplies = i
@ -179,7 +179,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(excludeReblogsString) i, err := strconv.ParseBool(excludeReblogsString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExcludeReblogsKey, err) err := fmt.Errorf("error parsing %s: %s", ExcludeReblogsKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
excludeReblogs = i excludeReblogs = i
@ -203,7 +203,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(pinnedString) i, err := strconv.ParseBool(pinnedString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", PinnedKey, err) err := fmt.Errorf("error parsing %s: %s", PinnedKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
pinnedOnly = i pinnedOnly = i
@ -215,7 +215,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(mediaOnlyString) i, err := strconv.ParseBool(mediaOnlyString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", OnlyMediaKey, err) err := fmt.Errorf("error parsing %s: %s", OnlyMediaKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
mediaOnly = i mediaOnly = i
@ -227,7 +227,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(publicOnlyString) i, err := strconv.ParseBool(publicOnlyString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", OnlyPublicKey, err) err := fmt.Errorf("error parsing %s: %s", OnlyPublicKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
publicOnly = i publicOnly = i
@ -235,7 +235,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
resp, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, excludeReblogs, maxID, minID, pinnedOnly, mediaOnly, publicOnly) resp, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, excludeReblogs, maxID, minID, pinnedOnly, mediaOnly, publicOnly)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account_test package accounts_test
import ( import (
"encoding/json" "encoding/json"
@ -29,7 +29,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
) )
@ -45,13 +45,13 @@ func (suite *AccountStatusesTestSuite) TestGetStatusesPublicOnly() {
ctx := suite.newContext(recorder, http.MethodGet, nil, fmt.Sprintf("/api/v1/accounts/%s/statuses?limit=20&only_media=false&only_public=true", targetAccount.ID), "") ctx := suite.newContext(recorder, http.MethodGet, nil, fmt.Sprintf("/api/v1/accounts/%s/statuses?limit=20&only_media=false&only_public=true", targetAccount.ID), "")
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: account.IDKey, Key: accounts.IDKey,
Value: targetAccount.ID, Value: targetAccount.ID,
}, },
} }
// call the handler // call the handler
suite.accountModule.AccountStatusesGETHandler(ctx) suite.accountsModule.AccountStatusesGETHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)
@ -85,13 +85,13 @@ func (suite *AccountStatusesTestSuite) TestGetStatusesPublicOnlyMediaOnly() {
ctx := suite.newContext(recorder, http.MethodGet, nil, fmt.Sprintf("/api/v1/accounts/%s/statuses?limit=20&only_media=true&only_public=true", targetAccount.ID), "") ctx := suite.newContext(recorder, http.MethodGet, nil, fmt.Sprintf("/api/v1/accounts/%s/statuses?limit=20&only_media=true&only_public=true", targetAccount.ID), "")
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: account.IDKey, Key: accounts.IDKey,
Value: targetAccount.ID, Value: targetAccount.ID,
}, },
} }
// call the handler // call the handler
suite.accountModule.AccountStatusesGETHandler(ctx) suite.accountsModule.AccountStatusesGETHandler(ctx)
// 1. we should have OK because our request was valid // 1. we should have OK because our request was valid
suite.Equal(http.StatusOK, recorder.Code) suite.Equal(http.StatusOK, recorder.Code)

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -70,25 +70,25 @@ import (
func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) { func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID) relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package account package accounts
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -70,25 +70,25 @@ import (
func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) { func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID) relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -24,8 +24,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -85,38 +85,38 @@ import (
func (m *Module) AccountActionPOSTHandler(c *gin.Context) { func (m *Module) AccountActionPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.AdminAccountActionRequest{} form := &apimodel.AdminAccountActionRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if form.Type == "" { if form.Type == "" {
err := errors.New("no type specified") err := errors.New("no type specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
targetAcctID := c.Param(IDKey) targetAcctID := c.Param(IDKey)
if targetAcctID == "" { if targetAcctID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form.TargetAccountID = targetAcctID form.TargetAccountID = targetAcctID
if errWithCode := m.processor.AdminAccountAction(c.Request.Context(), authed, form); errWithCode != nil { if errWithCode := m.processor.AdminAccountAction(c.Request.Context(), authed, form); errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -21,14 +21,13 @@ package admin
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// BasePath is the base API path for this module. // BasePath is the base API path for this module, excluding the api prefix
BasePath = "/api/v1/admin" BasePath = "/v1/admin"
// EmojiPath is used for posting/deleting custom emojis. // EmojiPath is used for posting/deleting custom emojis.
EmojiPath = BasePath + "/custom_emojis" EmojiPath = BasePath + "/custom_emojis"
// EmojiPathWithID is used for interacting with a single emoji. // EmojiPathWithID is used for interacting with a single emoji.
@ -68,32 +67,28 @@ const (
DomainQueryKey = "domain" DomainQueryKey = "domain"
) )
// Module implements the ClientAPIModule interface for admin-related actions (reports, emojis, etc)
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new admin module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodPost, EmojiPath, m.EmojiCreatePOSTHandler)
r.AttachHandler(http.MethodPost, EmojiPath, m.EmojiCreatePOSTHandler) attachHandler(http.MethodGet, EmojiPath, m.EmojisGETHandler)
r.AttachHandler(http.MethodGet, EmojiPath, m.EmojisGETHandler) attachHandler(http.MethodDelete, EmojiPathWithID, m.EmojiDELETEHandler)
r.AttachHandler(http.MethodDelete, EmojiPathWithID, m.EmojiDELETEHandler) attachHandler(http.MethodGet, EmojiPathWithID, m.EmojiGETHandler)
r.AttachHandler(http.MethodGet, EmojiPathWithID, m.EmojiGETHandler) attachHandler(http.MethodPatch, EmojiPathWithID, m.EmojiPATCHHandler)
r.AttachHandler(http.MethodPatch, EmojiPathWithID, m.EmojiPATCHHandler) attachHandler(http.MethodPost, DomainBlocksPath, m.DomainBlocksPOSTHandler)
r.AttachHandler(http.MethodPost, DomainBlocksPath, m.DomainBlocksPOSTHandler) attachHandler(http.MethodGet, DomainBlocksPath, m.DomainBlocksGETHandler)
r.AttachHandler(http.MethodGet, DomainBlocksPath, m.DomainBlocksGETHandler) attachHandler(http.MethodGet, DomainBlocksPathWithID, m.DomainBlockGETHandler)
r.AttachHandler(http.MethodGet, DomainBlocksPathWithID, m.DomainBlockGETHandler) attachHandler(http.MethodDelete, DomainBlocksPathWithID, m.DomainBlockDELETEHandler)
r.AttachHandler(http.MethodDelete, DomainBlocksPathWithID, m.DomainBlockDELETEHandler) attachHandler(http.MethodPost, AccountsActionPath, m.AccountActionPOSTHandler)
r.AttachHandler(http.MethodPost, AccountsActionPath, m.AccountActionPOSTHandler) attachHandler(http.MethodPost, MediaCleanupPath, m.MediaCleanupPOSTHandler)
r.AttachHandler(http.MethodPost, MediaCleanupPath, m.MediaCleanupPOSTHandler) attachHandler(http.MethodPost, MediaRefetchPath, m.MediaRefetchPOSTHandler)
r.AttachHandler(http.MethodPost, MediaRefetchPath, m.MediaRefetchPOSTHandler) attachHandler(http.MethodGet, EmojiCategoriesPath, m.EmojiCategoriesGETHandler)
r.AttachHandler(http.MethodGet, EmojiCategoriesPath, m.EmojiCategoriesGETHandler)
return nil
} }

View file

@ -93,7 +93,7 @@ func (suite *AdminStandardTestSuite) SetupTest() {
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.adminModule = admin.New(suite.processor).(*admin.Module) suite.adminModule = admin.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
} }

View file

@ -25,8 +25,8 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -126,18 +126,18 @@ import (
func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -147,21 +147,21 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
i, err := strconv.ParseBool(importString) i, err := strconv.ParseBool(importString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", ImportQueryKey, err) err := fmt.Errorf("error parsing %s: %s", ImportQueryKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
imp = i imp = i
} }
form := &model.DomainBlockCreateRequest{} form := &apimodel.DomainBlockCreateRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if err := validateCreateDomainBlock(form, imp); err != nil { if err := validateCreateDomainBlock(form, imp); err != nil {
err := fmt.Errorf("error validating form: %s", err) err := fmt.Errorf("error validating form: %s", err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -169,7 +169,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
// we're importing multiple blocks // we're importing multiple blocks
domainBlocks, errWithCode := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form) domainBlocks, errWithCode := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.JSON(http.StatusOK, domainBlocks) c.JSON(http.StatusOK, domainBlocks)
@ -179,13 +179,13 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
// we're just creating one block // we're just creating one block
domainBlock, errWithCode := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form) domainBlock, errWithCode := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.JSON(http.StatusOK, domainBlock) c.JSON(http.StatusOK, domainBlock)
} }
func validateCreateDomainBlock(form *model.DomainBlockCreateRequest, imp bool) error { func validateCreateDomainBlock(form *apimodel.DomainBlockCreateRequest, imp bool) error {
if imp { if imp {
if form.Domains.Size == 0 { if form.Domains.Size == 0 {
return errors.New("import was specified but list of domains is empty") return errors.New("import was specified but list of domains is empty")

View file

@ -24,7 +24,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -72,31 +72,31 @@ import (
func (m *Module) DomainBlockDELETEHandler(c *gin.Context) { func (m *Module) DomainBlockDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
domainBlockID := c.Param(IDKey) domainBlockID := c.Param(IDKey)
if domainBlockID == "" { if domainBlockID == "" {
err := errors.New("no domain block id specified") err := errors.New("no domain block id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID) domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -25,7 +25,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -73,25 +73,25 @@ import (
func (m *Module) DomainBlockGETHandler(c *gin.Context) { func (m *Module) DomainBlockGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
domainBlockID := c.Param(IDKey) domainBlockID := c.Param(IDKey)
if domainBlockID == "" { if domainBlockID == "" {
err := errors.New("no domain block id specified") err := errors.New("no domain block id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -101,7 +101,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(exportString) i, err := strconv.ParseBool(exportString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err) err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
export = i export = i
@ -109,7 +109,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
domainBlock, errWithCode := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export) domainBlock, errWithCode := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -24,7 +24,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -78,18 +78,18 @@ import (
func (m *Module) DomainBlocksGETHandler(c *gin.Context) { func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -99,7 +99,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
i, err := strconv.ParseBool(exportString) i, err := strconv.ParseBool(exportString)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err) err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
export = i export = i
@ -107,7 +107,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
domainBlocks, errWithCode := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export) domainBlocks, errWithCode := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -23,7 +23,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -69,24 +69,24 @@ import (
func (m *Module) EmojiCategoriesGETHandler(c *gin.Context) { func (m *Module) EmojiCategoriesGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
categories, errWithCode := m.processor.AdminEmojiCategoriesGet(c.Request.Context()) categories, errWithCode := m.processor.AdminEmojiCategoriesGet(c.Request.Context())
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -24,8 +24,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -100,42 +100,42 @@ import (
func (m *Module) EmojiCreatePOSTHandler(c *gin.Context) { func (m *Module) EmojiCreatePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.EmojiCreateRequest{} form := &apimodel.EmojiCreateRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if err := validateCreateEmoji(form); err != nil { if err := validateCreateEmoji(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
apiEmoji, errWithCode := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form) apiEmoji, errWithCode := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
c.JSON(http.StatusOK, apiEmoji) c.JSON(http.StatusOK, apiEmoji)
} }
func validateCreateEmoji(form *model.EmojiCreateRequest) error { func validateCreateEmoji(form *apimodel.EmojiCreateRequest) error {
if form.Image == nil || form.Image.Size == 0 { if form.Image == nil || form.Image.Size == 0 {
return errors.New("no emoji given") return errors.New("no emoji given")
} }

View file

@ -24,7 +24,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -78,31 +78,31 @@ import (
func (m *Module) EmojiDELETEHandler(c *gin.Context) { func (m *Module) EmojiDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
emojiID := c.Param(IDKey) emojiID := c.Param(IDKey)
if emojiID == "" { if emojiID == "" {
err := errors.New("no emoji id specified") err := errors.New("no emoji id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
emoji, errWithCode := m.processor.AdminEmojiDelete(c.Request.Context(), authed, emojiID) emoji, errWithCode := m.processor.AdminEmojiDelete(c.Request.Context(), authed, emojiID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -24,7 +24,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -68,31 +68,31 @@ import (
func (m *Module) EmojiGETHandler(c *gin.Context) { func (m *Module) EmojiGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
emojiID := c.Param(IDKey) emojiID := c.Param(IDKey)
if emojiID == "" { if emojiID == "" {
err := errors.New("no emoji id specified") err := errors.New("no emoji id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
emoji, errWithCode := m.processor.AdminEmojiGet(c.Request.Context(), authed, emojiID) emoji, errWithCode := m.processor.AdminEmojiGet(c.Request.Context(), authed, emojiID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -25,7 +25,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -125,18 +125,18 @@ import (
func (m *Module) EmojisGETHandler(c *gin.Context) { func (m *Module) EmojisGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -149,7 +149,7 @@ func (m *Module) EmojisGETHandler(c *gin.Context) {
i, err := strconv.ParseInt(limitString, 10, 32) i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err) err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
limit = int(i) limit = int(i)
@ -177,7 +177,7 @@ func (m *Module) EmojisGETHandler(c *gin.Context) {
shortcode = strings.Trim(filter[10:], ":") // remove any errant ":" shortcode = strings.Trim(filter[10:], ":") // remove any errant ":"
default: default:
err := fmt.Errorf("filter %s not recognized; accepted values are 'domain:[domain]', 'disabled', 'enabled', 'shortcode:[shortcode]'", filter) err := fmt.Errorf("filter %s not recognized; accepted values are 'domain:[domain]', 'disabled', 'enabled', 'shortcode:[shortcode]'", filter)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
} }
@ -200,7 +200,7 @@ func (m *Module) EmojisGETHandler(c *gin.Context) {
resp, errWithCode := m.processor.AdminEmojisGet(c.Request.Context(), authed, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit) resp, errWithCode := m.processor.AdminEmojisGet(c.Request.Context(), authed, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -25,8 +25,8 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -123,42 +123,42 @@ import (
func (m *Module) EmojiPATCHHandler(c *gin.Context) { func (m *Module) EmojiPATCHHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
emojiID := c.Param(IDKey) emojiID := c.Param(IDKey)
if emojiID == "" { if emojiID == "" {
err := errors.New("no emoji id specified") err := errors.New("no emoji id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.EmojiUpdateRequest{} form := &apimodel.EmojiUpdateRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if err := validateUpdateEmoji(form); err != nil { if err := validateUpdateEmoji(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
emoji, errWithCode := m.processor.AdminEmojiUpdate(c.Request.Context(), emojiID, form) emoji, errWithCode := m.processor.AdminEmojiUpdate(c.Request.Context(), emojiID, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
@ -166,14 +166,14 @@ func (m *Module) EmojiPATCHHandler(c *gin.Context) {
} }
// do a first pass on the form here // do a first pass on the form here
func validateUpdateEmoji(form *model.EmojiUpdateRequest) error { func validateUpdateEmoji(form *apimodel.EmojiUpdateRequest) error {
// check + normalize update type so we don't need // check + normalize update type so we don't need
// to do this trimming + lowercasing again later // to do this trimming + lowercasing again later
switch strings.TrimSpace(strings.ToLower(string(form.Type))) { switch strings.TrimSpace(strings.ToLower(string(form.Type))) {
case string(model.EmojiUpdateDisable): case string(apimodel.EmojiUpdateDisable):
// no params required for this one, so don't bother checking // no params required for this one, so don't bother checking
form.Type = model.EmojiUpdateDisable form.Type = apimodel.EmojiUpdateDisable
case string(model.EmojiUpdateCopy): case string(apimodel.EmojiUpdateCopy):
// need at least a valid shortcode when doing a copy // need at least a valid shortcode when doing a copy
if form.Shortcode == nil { if form.Shortcode == nil {
return errors.New("emoji action type was 'copy' but no shortcode was provided") return errors.New("emoji action type was 'copy' but no shortcode was provided")
@ -190,8 +190,8 @@ func validateUpdateEmoji(form *model.EmojiUpdateRequest) error {
} }
} }
form.Type = model.EmojiUpdateCopy form.Type = apimodel.EmojiUpdateCopy
case string(model.EmojiUpdateModify): case string(apimodel.EmojiUpdateModify):
// need either image or category name for modify // need either image or category name for modify
hasImage := form.Image != nil && form.Image.Size != 0 hasImage := form.Image != nil && form.Image.Size != 0
hasCategoryName := form.CategoryName != nil hasCategoryName := form.CategoryName != nil
@ -212,7 +212,7 @@ func validateUpdateEmoji(form *model.EmojiUpdateRequest) error {
} }
} }
form.Type = model.EmojiUpdateModify form.Type = apimodel.EmojiUpdateModify
default: default:
return errors.New("emoji action type must be one of 'disable', 'copy', 'modify'") return errors.New("emoji action type must be one of 'disable', 'copy', 'modify'")
} }

View file

@ -23,8 +23,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -71,19 +71,19 @@ import (
func (m *Module) MediaCleanupPOSTHandler(c *gin.Context) { func (m *Module) MediaCleanupPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.MediaCleanupRequest{} form := &apimodel.MediaCleanupRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -98,7 +98,7 @@ func (m *Module) MediaCleanupPOSTHandler(c *gin.Context) {
} }
if errWithCode := m.processor.AdminMediaPrune(c.Request.Context(), remoteCacheDays); errWithCode != nil { if errWithCode := m.processor.AdminMediaPrune(c.Request.Context(), remoteCacheDays); errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -23,7 +23,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -74,18 +74,18 @@ import (
func (m *Module) MediaRefetchPOSTHandler(c *gin.Context) { func (m *Module) MediaRefetchPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if !*authed.User.Admin { if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID) err := fmt.Errorf("user %s not an admin", authed.User.ID)
api.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGet)
return return
} }
if errWithCode := m.processor.AdminMediaRefetch(c.Request.Context(), authed, c.Query(DomainQueryKey)); errWithCode != nil { if errWithCode := m.processor.AdminMediaRefetch(c.Request.Context(), authed, c.Query(DomainQueryKey)); errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -1,21 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app_test
// TODO: write tests

View file

@ -16,15 +16,15 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package app package apps
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -77,48 +77,48 @@ const (
func (m *Module) AppsPOSTHandler(c *gin.Context) { func (m *Module) AppsPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, false, false, false, false) authed, err := oauth.Authed(c, false, false, false, false)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
form := &model.ApplicationCreateRequest{} form := &apimodel.ApplicationCreateRequest{}
if err := c.ShouldBind(form); err != nil { if err := c.ShouldBind(form); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if len([]rune(form.ClientName)) > formFieldLen { if len([]rune(form.ClientName)) > formFieldLen {
err := fmt.Errorf("client_name must be less than %d characters", formFieldLen) err := fmt.Errorf("client_name must be less than %d characters", formFieldLen)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if len([]rune(form.RedirectURIs)) > formRedirectLen { if len([]rune(form.RedirectURIs)) > formRedirectLen {
err := fmt.Errorf("redirect_uris must be less than %d characters", formRedirectLen) err := fmt.Errorf("redirect_uris must be less than %d characters", formRedirectLen)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if len([]rune(form.Scopes)) > formFieldLen { if len([]rune(form.Scopes)) > formFieldLen {
err := fmt.Errorf("scopes must be less than %d characters", formFieldLen) err := fmt.Errorf("scopes must be less than %d characters", formFieldLen)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if len([]rune(form.Website)) > formFieldLen { if len([]rune(form.Website)) > formFieldLen {
err := fmt.Errorf("website must be less than %d characters", formFieldLen) err := fmt.Errorf("website must be less than %d characters", formFieldLen)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
apiApp, errWithCode := m.processor.AppCreate(c.Request.Context(), authed, form) apiApp, errWithCode := m.processor.AppCreate(c.Request.Context(), authed, form)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,33 +16,28 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package app package apps
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
// BasePath is the base path for this api module // BasePath is the base path for this api module, excluding the api prefix
const BasePath = "/api/v1/apps" const BasePath = "/v1/apps"
// Module implements the ClientAPIModule interface for requests relating to registering/removing applications
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new auth module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route satisfies the RESTAPIModule interface func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(s router.Router) error { attachHandler(http.MethodPost, BasePath, m.AppsPOSTHandler)
s.AttachHandler(http.MethodPost, BasePath, m.AppsPOSTHandler)
return nil
} }

View file

@ -1,105 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package auth
import (
"net/http"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
/* #nosec G101 */
const (
// AuthSignInPath is the API path for users to sign in through
AuthSignInPath = "/auth/sign_in"
// CheckYourEmailPath users land here after registering a new account, instructs them to confirm thier email
CheckYourEmailPath = "/check_your_email"
// WaitForApprovalPath users land here after confirming thier email but before an admin approves thier account
// (if such is required)
WaitForApprovalPath = "/wait_for_approval"
// AccountDisabledPath users land here when thier account is suspended by an admin
AccountDisabledPath = "/account_disabled"
// OauthTokenPath is the API path to use for granting token requests to users with valid credentials
OauthTokenPath = "/oauth/token"
// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)
OauthAuthorizePath = "/oauth/authorize"
// OauthFinalizePath is the API path for completing user registration with additional user details
OauthFinalizePath = "/oauth/finalize"
// CallbackPath is the API path for receiving callback tokens from external OIDC providers
CallbackPath = oidc.CallbackPath
callbackStateParam = "state"
callbackCodeParam = "code"
sessionUserID = "userid"
sessionClientID = "client_id"
sessionRedirectURI = "redirect_uri"
sessionForceLogin = "force_login"
sessionResponseType = "response_type"
sessionScope = "scope"
sessionInternalState = "internal_state"
sessionClientState = "client_state"
sessionClaims = "claims"
sessionAppID = "app_id"
)
// Module implements the ClientAPIModule interface for
type Module struct {
db db.DB
idp oidc.IDP
processor processing.Processor
}
// New returns a new auth module
func New(db db.DB, idp oidc.IDP, processor processing.Processor) api.ClientModule {
return &Module{
db: db,
idp: idp,
processor: processor,
}
}
// Route satisfies the RESTAPIModule interface
func (m *Module) Route(s router.Router) error {
s.AttachHandler(http.MethodGet, AuthSignInPath, m.SignInGETHandler)
s.AttachHandler(http.MethodPost, AuthSignInPath, m.SignInPOSTHandler)
s.AttachHandler(http.MethodPost, OauthTokenPath, m.TokenPOSTHandler)
s.AttachHandler(http.MethodGet, OauthAuthorizePath, m.AuthorizeGETHandler)
s.AttachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler)
s.AttachHandler(http.MethodGet, CallbackPath, m.CallbackGETHandler)
s.AttachHandler(http.MethodPost, OauthFinalizePath, m.FinalizePOSTHandler)
s.AttachHandler(http.MethodGet, oauth.OOBTokenPath, m.OobHandler)
return nil
}

View file

@ -21,14 +21,13 @@ package blocks
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// BasePath is the base URI path for serving favourites // BasePath is the base URI path for serving blocks, minus the api prefix.
BasePath = "/api/v1/blocks" BasePath = "/v1/blocks"
// MaxIDKey is the url query for setting a max ID to return // MaxIDKey is the url query for setting a max ID to return
MaxIDKey = "max_id" MaxIDKey = "max_id"
@ -38,20 +37,16 @@ const (
LimitKey = "limit" LimitKey = "limit"
) )
// Module implements the ClientAPIModule interface for everything relating to viewing blocks
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new blocks module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, BasePath, m.BlocksGETHandler)
r.AttachHandler(http.MethodGet, BasePath, m.BlocksGETHandler)
return nil
} }

View file

@ -24,7 +24,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -96,12 +96,12 @@ import (
func (m *Module) BlocksGETHandler(c *gin.Context) { func (m *Module) BlocksGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -123,7 +123,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
i, err := strconv.ParseInt(limitString, 10, 32) i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err) err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
limit = int(i) limit = int(i)
@ -131,7 +131,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit) resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -21,9 +21,8 @@ package bookmarks
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
@ -31,20 +30,16 @@ const (
BasePath = "/api/v1/bookmarks" BasePath = "/api/v1/bookmarks"
) )
// Module implements the ClientAPIModule interface for everything related to bookmarks
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new emoji module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, BasePath, m.BookmarksGETHandler)
r.AttachHandler(http.MethodGet, BasePath, m.BookmarksGETHandler)
return nil
} }

View file

@ -29,7 +29,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks" "github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status" "github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
"github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -67,7 +67,7 @@ type BookmarkTestSuite struct {
testFollows map[string]*gtsmodel.Follow testFollows map[string]*gtsmodel.Follow
// module being tested // module being tested
statusModule *status.Module statusModule *statuses.Module
bookmarkModule *bookmarks.Module bookmarkModule *bookmarks.Module
} }
@ -99,8 +99,8 @@ func (suite *BookmarkTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.statusModule = status.New(suite.processor).(*status.Module) suite.statusModule = statuses.New(suite.processor)
suite.bookmarkModule = bookmarks.New(suite.processor).(*bookmarks.Module) suite.bookmarkModule = bookmarks.New(suite.processor)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }
@ -123,7 +123,7 @@ func (suite *BookmarkTestSuite) TestGetBookmark() {
ctx.Set(oauth.SessionAuthorizedToken, oauthToken) ctx.Set(oauth.SessionAuthorizedToken, oauthToken)
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(status.BookmarkPath, ":id", targetStatus.ID, 1)), nil) // the endpoint we're hitting ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080%s", strings.Replace(statuses.BookmarkPath, ":id", targetStatus.ID, 1)), nil) // the endpoint we're hitting
ctx.Request.Header.Set("accept", "application/json") ctx.Request.Header.Set("accept", "application/json")
suite.bookmarkModule.BookmarksGETHandler(ctx) suite.bookmarkModule.BookmarksGETHandler(ctx)

View file

@ -6,7 +6,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -56,12 +56,12 @@ const (
func (m *Module) BookmarksGETHandler(c *gin.Context) { func (m *Module) BookmarksGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -71,7 +71,7 @@ func (m *Module) BookmarksGETHandler(c *gin.Context) {
i, err := strconv.ParseInt(limitString, 10, 64) i, err := strconv.ParseInt(limitString, 10, 64)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err) err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
limit = int(i) limit = int(i)
@ -91,12 +91,12 @@ func (m *Module) BookmarksGETHandler(c *gin.Context) {
resp, errWithCode := m.processor.BookmarksGet(c.Request.Context(), authed, maxID, minID, limit) resp, errWithCode := m.processor.BookmarksGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -0,0 +1,45 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package customemojis
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing"
)
const (
// BasePath is the base path for serving custom emojis, minus the 'api' prefix
BasePath = "/v1/custom_emojis"
)
type Module struct {
processor processing.Processor
}
func New(processor processing.Processor) *Module {
return &Module{
processor: processor,
}
}
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, BasePath, m.CustomEmojisGETHandler)
}

View file

@ -0,0 +1,76 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package customemojis
import (
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// CustomEmojisGETHandler swagger:operation GET /api/v1/custom_emojis customEmojisGet
//
// Get an array of custom emojis available on the instance.
//
// ---
// tags:
// - custom_emojis
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - read:custom_emojis
//
// responses:
// '200':
// description: Array of custom emojis.
// schema:
// type: array
// items:
// "$ref": "#/definitions/emoji"
// '401':
// description: unauthorized
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) CustomEmojisGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return
}
emojis, errWithCode := m.processor.CustomEmojisGet(c)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return
}
c.JSON(http.StatusOK, emojis)
}

View file

@ -1,58 +0,0 @@
package emoji
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// EmojisGETHandler swagger:operation GET /api/v1/custom_emojis customEmojisGet
//
// Get an array of custom emojis available on the instance.
//
// ---
// tags:
// - custom_emojis
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - read:custom_emojis
//
// responses:
// '200':
// description: Array of custom emojis.
// schema:
// type: array
// items:
// "$ref": "#/definitions/emoji"
// '401':
// description: unauthorized
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) EmojisGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return
}
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return
}
emojis, errWithCode := m.processor.CustomEmojisGet(c)
if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return
}
c.JSON(http.StatusOK, emojis)
}

View file

@ -21,14 +21,13 @@ package favourites
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// BasePath is the base URI path for serving favourites // BasePath is the base URI path for serving favourites, minus the 'api' prefix
BasePath = "/api/v1/favourites" BasePath = "/v1/favourites"
// MaxIDKey is the url query for setting a max status ID to return // MaxIDKey is the url query for setting a max status ID to return
MaxIDKey = "max_id" MaxIDKey = "max_id"
@ -42,20 +41,16 @@ const (
LocalKey = "local" LocalKey = "local"
) )
// Module implements the ClientAPIModule interface for everything relating to viewing favourites
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new favourites module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, BasePath, m.FavouritesGETHandler)
r.AttachHandler(http.MethodGet, BasePath, m.FavouritesGETHandler)
return nil
} }

View file

@ -87,7 +87,7 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.favModule = favourites.New(suite.processor).(*favourites.Module) suite.favModule = favourites.New(suite.processor)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }

View file

@ -6,7 +6,7 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -78,12 +78,12 @@ import (
func (m *Module) FavouritesGETHandler(c *gin.Context) { func (m *Module) FavouritesGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
@ -105,7 +105,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
i, err := strconv.ParseInt(limitString, 10, 32) i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil { if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err) err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
limit = int(i) limit = int(i)
@ -113,7 +113,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit) resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -21,30 +21,25 @@ package filter
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// BasePath is the base path for serving the filter API // BasePath is the base path for serving the filters API, minus the 'api' prefix
BasePath = "/api/v1/filters" BasePath = "/v1/filters"
) )
// Module implements the ClientAPIModule interface for every related to filters
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new filter module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, BasePath, m.FiltersGETHandler)
r.AttachHandler(http.MethodGet, BasePath, m.FiltersGETHandler)
return nil
} }

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -12,12 +12,12 @@ import (
// FiltersGETHandler returns a list of filters set by/for the authed account // FiltersGETHandler returns a list of filters set by/for the authed account
func (m *Module) FiltersGETHandler(c *gin.Context) { func (m *Module) FiltersGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil { if _, err := oauth.Authed(c, true, true, true, true); err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest package followrequests
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -72,25 +72,25 @@ import (
func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) { func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
originAccountID := c.Param(IDKey) originAccountID := c.Param(IDKey)
if originAccountID == "" { if originAccountID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest_test package followrequests_test
import ( import (
"context" "context"
@ -30,7 +30,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
@ -60,7 +60,7 @@ func (suite *AuthorizeTestSuite) TestAuthorize() {
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: followrequest.IDKey, Key: followrequests.IDKey,
Value: requestingAccount.ID, Value: requestingAccount.ID,
}, },
} }
@ -90,7 +90,7 @@ func (suite *AuthorizeTestSuite) TestAuthorizeNoFR() {
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: followrequest.IDKey, Key: followrequests.IDKey,
Value: requestingAccount.ID, Value: requestingAccount.ID,
}, },
} }

View file

@ -16,21 +16,20 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest package followrequests
import ( import (
"net/http" "net/http"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
) )
const ( const (
// IDKey is for account IDs // IDKey is for account IDs
IDKey = "id" IDKey = "id"
// BasePath is the base path for serving the follow request API // BasePath is the base path for serving the follow request API, minus the 'api' prefix
BasePath = "/api/v1/follow_requests" BasePath = "/v1/follow_requests"
// BasePathWithID is just the base path with the ID key in it. // BasePathWithID is just the base path with the ID key in it.
// Use this anywhere you need to know the ID of the account that owns the follow request being queried. // Use this anywhere you need to know the ID of the account that owns the follow request being queried.
BasePathWithID = BasePath + "/:" + IDKey BasePathWithID = BasePath + "/:" + IDKey
@ -40,22 +39,18 @@ const (
RejectPath = BasePathWithID + "/reject" RejectPath = BasePathWithID + "/reject"
) )
// Module implements the ClientAPIModule interface
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
} }
// New returns a new follow request module func New(processor processing.Processor) *Module {
func New(processor processing.Processor) api.ClientModule {
return &Module{ return &Module{
processor: processor, processor: processor,
} }
} }
// Route attaches all routes from this module to the given router func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
func (m *Module) Route(r router.Router) error { attachHandler(http.MethodGet, BasePath, m.FollowRequestGETHandler)
r.AttachHandler(http.MethodGet, BasePath, m.FollowRequestGETHandler) attachHandler(http.MethodPost, AuthorizePath, m.FollowRequestAuthorizePOSTHandler)
r.AttachHandler(http.MethodPost, AuthorizePath, m.FollowRequestAuthorizePOSTHandler) attachHandler(http.MethodPost, RejectPath, m.FollowRequestRejectPOSTHandler)
r.AttachHandler(http.MethodPost, RejectPath, m.FollowRequestRejectPOSTHandler)
return nil
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest_test package followrequests_test
import ( import (
"bytes" "bytes"
@ -25,7 +25,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -59,7 +59,7 @@ type FollowRequestStandardTestSuite struct {
testStatuses map[string]*gtsmodel.Status testStatuses map[string]*gtsmodel.Status
// module being tested // module being tested
followRequestModule *followrequest.Module followRequestModule *followrequests.Module
} }
func (suite *FollowRequestStandardTestSuite) SetupSuite() { func (suite *FollowRequestStandardTestSuite) SetupSuite() {
@ -85,7 +85,7 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.followRequestModule = followrequest.New(suite.processor).(*followrequest.Module) suite.followRequestModule = followrequests.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")

View file

@ -16,13 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest package followrequests
import ( import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -74,18 +74,18 @@ import (
func (m *Module) FollowRequestGETHandler(c *gin.Context) { func (m *Module) FollowRequestGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest_test package followrequests_test
import ( import (
"context" "context"

View file

@ -16,14 +16,14 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest package followrequests
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
) )
@ -70,25 +70,25 @@ import (
func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) { func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true) authed, err := oauth.Authed(c, true, true, true, true)
if err != nil { if err != nil {
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
if _, err := api.NegotiateAccept(c, api.JSONAcceptHeaders...); err != nil { if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
api.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet)
return return
} }
originAccountID := c.Param(IDKey) originAccountID := c.Param(IDKey)
if originAccountID == "" { if originAccountID == "" {
err := errors.New("no account id specified") err := errors.New("no account id specified")
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID) relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID)
if errWithCode != nil { if errWithCode != nil {
api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
} }

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package followrequest_test package followrequests_test
import ( import (
"context" "context"
@ -30,7 +30,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
@ -60,7 +60,7 @@ func (suite *RejectTestSuite) TestReject() {
ctx.Params = gin.Params{ ctx.Params = gin.Params{
gin.Param{ gin.Param{
Key: followrequest.IDKey, Key: followrequests.IDKey,
Value: requestingAccount.ID, Value: requestingAccount.ID,
}, },
} }

Some files were not shown because too many files have changed in this diff Show more