[feature] add authorization to the already-existing authentication (#365)

* add ensureUserIsAuthorizedOrRedirect to /oauth/authorize

* adding authorization (email confirm, account approve, etc) to TokenCheck

* revert un-needed changes to signin.go

* oops what happened here

* error css

* add account.SuspendedAt check

* remove redundant checks from oauth util Authed function

* wip tests

* tests passing

* stop stripping useful information from ErrAlreadyExists

* that feeling of scraping the dryer LINT off the screen

* oops I didn't mean to get rid of this NewTestRouter function

* make tests work with recorder

* re-add ConfigureTemplatesWithGin to handle template path err

Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
This commit is contained in:
Forest Johnson 2022-02-07 11:04:31 +00:00 committed by GitHub
parent 5c9d20cea3
commit 6ed368cbeb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 424 additions and 47 deletions

View file

@ -32,10 +32,23 @@ import (
const ( const (
// AuthSignInPath is the API path for users to sign in through // AuthSignInPath is the API path for users to sign in through
AuthSignInPath = "/auth/sign_in" AuthSignInPath = "/auth/sign_in"
// CheckYourEmailPath users land here after registering a new account, instructs them to confirm thier email
CheckYourEmailPath = "/check_your_email"
// WaitForApprovalPath users land here after confirming thier email but before an admin approves thier account
// (if such is required)
WaitForApprovalPath = "/wait_for_approval"
// AccountDisabledPath users land here when thier account is suspended by an admin
AccountDisabledPath = "/account_disabled"
// OauthTokenPath is the API path to use for granting token requests to users with valid credentials // OauthTokenPath is the API path to use for granting token requests to users with valid credentials
OauthTokenPath = "/oauth/token" OauthTokenPath = "/oauth/token"
// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)
OauthAuthorizePath = "/oauth/authorize" OauthAuthorizePath = "/oauth/authorize"
// CallbackPath is the API path for receiving callback tokens from external OIDC providers // CallbackPath is the API path for receiving callback tokens from external OIDC providers
CallbackPath = oidc.CallbackPath CallbackPath = oidc.CallbackPath

View file

@ -18,4 +18,96 @@
package auth_test package auth_test
// TODO import (
"context"
"fmt"
"net/http/httptest"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/router"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AuthStandardTestSuite struct {
suite.Suite
db db.DB
idp oidc.IDP
oauthServer oauth.Server
// standard suite models
testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
// module being tested
authModule *auth.Module
}
const (
sessionUserID = "userid"
sessionClientID = "client_id"
)
func (suite *AuthStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
}
func (suite *AuthStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
suite.db = testrig.NewTestDB()
testrig.InitTestLog()
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
var err error
suite.idp, err = oidc.NewIDP(context.Background())
if err != nil {
panic(err)
}
suite.authModule = auth.New(suite.db, suite.oauthServer, suite.idp).(*auth.Module)
testrig.StandardDBSetup(suite.db, nil)
}
func (suite *AuthStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string) (*gin.Context, *httptest.ResponseRecorder) {
// create the recorder and gin test context
recorder := httptest.NewRecorder()
ctx, engine := gin.CreateTestContext(recorder)
// load templates into the engine
testrig.ConfigureTemplatesWithGin(engine)
// create the request
protocol := viper.GetString(config.Keys.Protocol)
host := viper.GetString(config.Keys.Host)
baseURI := fmt.Sprintf("%s://%s", protocol, host)
requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath)
ctx.Request = httptest.NewRequest(requestMethod, requestURI, nil) // the endpoint we're hitting
ctx.Request.Header.Set("accept", "text/html")
// trigger the session middleware on the context
store := memstore.NewStore(make([]byte, 32), make([]byte, 32))
store.Options(router.SessionOptions())
sessionMiddleware := sessions.Sessions("gotosocial-localhost", store)
sessionMiddleware(ctx)
return ctx, recorder
}

View file

@ -44,7 +44,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
s := sessions.Default(c) s := sessions.Default(c)
if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil {
c.JSON(http.StatusNotAcceptable, gin.H{"error": err.Error()}) c.HTML(http.StatusNotAcceptable, "error.tmpl", gin.H{"error": err.Error()})
return return
} }
@ -57,7 +57,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if err := c.Bind(form); err != nil { if err := c.Bind(form); err != nil {
l.Debugf("invalid auth form: %s", err) l.Debugf("invalid auth form: %s", err)
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})
return return
} }
l.Debugf("parsed auth form: %+v", form) l.Debugf("parsed auth form: %+v", form)
@ -65,7 +65,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
if err := extractAuthForm(s, form); err != nil { if err := extractAuthForm(s, form); err != nil {
l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err)) l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err))
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})
return return
} }
c.Redirect(http.StatusSeeOther, AuthSignInPath) c.Redirect(http.StatusSeeOther, AuthSignInPath)
@ -75,28 +75,33 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
// We can use the client_id on the session to retrieve info about the app associated with the client_id // We can use the client_id on the session to retrieve info about the app associated with the client_id
clientID, ok := s.Get(sessionClientID).(string) clientID, ok := s.Get(sessionClientID).(string)
if !ok || clientID == "" { if !ok || clientID == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no client_id found in session"})
return return
} }
app := &gtsmodel.Application{} app := &gtsmodel.Application{}
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{
"error": fmt.Sprintf("no application found for client id %s", clientID),
})
return return
} }
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3 // redirect the user if they have not confirmed their email yet, thier account has not been approved yet,
// or thier account has been disabled.
user := &gtsmodel.User{} user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()})
return return
} }
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil { if err != nil {
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()})
return
}
if !ensureUserIsAuthorizedOrRedirect(c, user, acct) {
return return
} }
@ -104,13 +109,13 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
redirect, ok := s.Get(sessionRedirectURI).(string) redirect, ok := s.Get(sessionRedirectURI).(string)
if !ok || redirect == "" { if !ok || redirect == "" {
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no redirect_uri found in session"})
return return
} }
scope, ok := s.Get(sessionScope).(string) scope, ok := s.Get(sessionScope).(string)
if !ok || scope == "" { if !ok || scope == "" {
m.clearSession(s) m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no scope found in session"})
return return
} }
@ -170,10 +175,28 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
errs = append(errs, "session missing userid") errs = append(errs, "session missing userid")
} }
// redirect the user if they have not confirmed their email yet, thier account has not been approved yet,
// or thier account has been disabled.
user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
m.clearSession(s)
c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()})
return
}
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil {
m.clearSession(s)
c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()})
return
}
if !ensureUserIsAuthorizedOrRedirect(c, user, acct) {
return
}
m.clearSession(s) m.clearSession(s)
if len(errs) != 0 { if len(errs) != 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": strings.Join(errs, ": ")}) c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": strings.Join(errs, ": ")})
return return
} }
@ -190,7 +213,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
// and proceed with authorization using the oauth2 library // and proceed with authorization using the oauth2 library
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})
} }
} }
@ -216,3 +239,27 @@ func extractAuthForm(s sessions.Session, form *model.OAuthAuthorize) error {
s.Set(sessionState, uuid.NewString()) s.Set(sessionState, uuid.NewString())
return s.Save() return s.Save()
} }
func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) bool {
if user.ConfirmedAt.IsZero() {
ctx.Redirect(http.StatusSeeOther, CheckYourEmailPath)
return false
}
if !user.Approved {
ctx.Redirect(http.StatusSeeOther, WaitForApprovalPath)
return false
}
if user.Disabled {
ctx.Redirect(http.StatusSeeOther, AccountDisabledPath)
return false
}
if !account.SuspendedAt.IsZero() {
ctx.Redirect(http.StatusSeeOther, AccountDisabledPath)
return false
}
return true
}

View file

@ -0,0 +1,113 @@
package auth_test
import (
"context"
"fmt"
"net/http"
"testing"
"time"
"codeberg.org/gruf/go-errors"
"github.com/gin-contrib/sessions"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/auth"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type AuthAuthorizeTestSuite struct {
AuthStandardTestSuite
}
type authorizeHandlerTestCase struct {
description string
mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account)
expectedStatusCode int
expectedLocationHeader string
}
func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
var tests = []authorizeHandlerTestCase{
{
description: "user has their email unconfirmed",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
// nothing to do, weed_lord420 already has their email unconfirmed
},
expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.CheckYourEmailPath,
},
{
description: "user has their email confirmed but is not approved",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail
},
expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.WaitForApprovalPath,
},
{
description: "user has their email confirmed and is approved, but User entity has been disabled",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail
user.Approved = true
user.Disabled = true
},
expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath,
},
{
description: "user has their email confirmed and is approved, but Account entity has been suspended",
mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) {
user.ConfirmedAt = time.Now()
user.Email = user.UnconfirmedEmail
user.Approved = true
user.Disabled = false
account.SuspendedAt = time.Now()
},
expectedStatusCode: http.StatusSeeOther,
expectedLocationHeader: auth.AccountDisabledPath,
},
}
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath)
user := suite.testUsers["unconfirmed_account"]
account := suite.testAccounts["unconfirmed_account"]
testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID)
if err := testSession.Save(); err != nil {
panic(errors.WrapMsgf(err, "failed on case: %s", testCase.description))
}
testCase.mutateUserAccount(user, account)
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, user.Disabled, account.SuspendedAt)
user.UpdatedAt = time.Now()
err := suite.db.UpdateByPrimaryKey(context.Background(), user)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
// call the handler
suite.authModule.AuthorizeGETHandler(ctx)
// 1. we should have a redirect
suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description))
// 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet.
suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description))
}
for _, testCase := range tests {
doTest(testCase)
}
}
func TestAccountUpdateTestSuite(t *testing.T) {
suite.Run(t, new(AuthAuthorizeTestSuite))
}

View file

@ -62,6 +62,22 @@ func (m *Module) TokenCheck(c *gin.Context) {
l.Warnf("no user found for userID %s", userID) l.Warnf("no user found for userID %s", userID)
return return
} }
if user.ConfirmedAt.IsZero() {
l.Warnf("authenticated user %s has never confirmed thier email address", userID)
return
}
if !user.Approved {
l.Warnf("authenticated user %s's account was never approved by an admin", userID)
return
}
if user.Disabled {
l.Warnf("authenticated user %s's account was disabled'", userID)
return
}
c.Set(oauth.SessionAuthorizedUser, user) c.Set(oauth.SessionAuthorizedUser, user)
// fetch account for this token // fetch account for this token
@ -74,6 +90,12 @@ func (m *Module) TokenCheck(c *gin.Context) {
l.Warnf("no account found for userID %s", userID) l.Warnf("no account found for userID %s", userID)
return return
} }
if !acct.SuspendedAt.IsZero() {
l.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return
}
c.Set(oauth.SessionAuthorizedAccount, acct) c.Set(oauth.SessionAuthorizedAccount, acct)
} }

View file

@ -19,7 +19,7 @@ func processPostgresError(err error) db.Error {
// (https://www.postgresql.org/docs/10/errcodes-appendix.html) // (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code { switch pgErr.Code {
case "23505" /* unique_violation */ : case "23505" /* unique_violation */ :
return db.ErrAlreadyExists return db.NewErrAlreadyExists(pgErr.Message)
default: default:
return err return err
} }
@ -36,7 +36,7 @@ func processSQLiteError(err error) db.Error {
// Handle supplied error code: // Handle supplied error code:
switch sqliteErr.Code() { switch sqliteErr.Code() {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE: case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
return db.ErrAlreadyExists return db.NewErrAlreadyExists(err.Error())
default: default:
return err return err
} }

View file

@ -28,8 +28,19 @@ var (
ErrNoEntries Error = fmt.Errorf("no entries") ErrNoEntries Error = fmt.Errorf("no entries")
// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
ErrMultipleEntries Error = fmt.Errorf("multiple entries") ErrMultipleEntries Error = fmt.Errorf("multiple entries")
// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
ErrAlreadyExists Error = fmt.Errorf("already exists")
// ErrUnknown denotes an unknown database error. // ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error") ErrUnknown Error = fmt.Errorf("unknown error")
) )
// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
type ErrAlreadyExists struct {
message string
}
func (e *ErrAlreadyExists) Error() string {
return e.message
}
func NewErrAlreadyExists(msg string) error {
return &ErrAlreadyExists{message: msg}
}

View file

@ -20,6 +20,7 @@ package dereferencing
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/url" "net/url"
@ -60,7 +61,8 @@ func (d *deref) GetRemoteAttachment(ctx context.Context, requestingUsername stri
} }
if err := d.db.Put(ctx, a); err != nil { if err := d.db.Put(ctx, a); err != nil {
if err != db.ErrAlreadyExists { var alreadyExistsError *db.ErrAlreadyExists
if !errors.As(err, &alreadyExistsError) {
return nil, fmt.Errorf("GetRemoteAttachment: error inserting attachment: %s", err) return nil, fmt.Errorf("GetRemoteAttachment: error inserting attachment: %s", err)
} }
} }

View file

@ -231,7 +231,8 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
status.ID = statusID status.ID = statusID
if err := f.db.PutStatus(ctx, status); err != nil { if err := f.db.PutStatus(ctx, status); err != nil {
if err == db.ErrAlreadyExists { var alreadyExistsError *db.ErrAlreadyExists
if errors.As(err, &alreadyExistsError) {
// the status already exists in the database, which means we've already handled everything else, // the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it. // so we can just return nil here and be done with it.
return nil return nil

View file

@ -78,25 +78,12 @@ func Authed(c *gin.Context, requireToken bool, requireApp bool, requireUser bool
return nil, errors.New("application not supplied") return nil, errors.New("application not supplied")
} }
if requireUser { if requireUser && a.User == nil {
if a.User == nil { return nil, errors.New("user not supplied or not authorized")
return nil, errors.New("user not supplied")
}
if a.User.Disabled || !a.User.Approved {
return nil, errors.New("user disabled or not approved")
}
if a.User.Email == "" {
return nil, errors.New("user has no confirmed email address")
}
} }
if requireAccount { if requireAccount && a.Account == nil {
if a.Account == nil { return nil, errors.New("account not supplied or not authorized")
return nil, errors.New("account not supplied")
}
if !a.Account.SuspendedAt.IsZero() {
return nil, errors.New("account suspended")
}
} }
return a, nil return a, nil

View file

@ -223,8 +223,11 @@ func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStat
return fmt.Errorf("error generating hashtags from status: %s", err) return fmt.Errorf("error generating hashtags from status: %s", err)
} }
for _, tag := range gtsTags { for _, tag := range gtsTags {
if err := p.db.Put(ctx, tag); err != nil && err != db.ErrAlreadyExists { if err := p.db.Put(ctx, tag); err != nil {
return fmt.Errorf("error putting tags in db: %s", err) var alreadyExistsError *db.ErrAlreadyExists
if !errors.As(err, &alreadyExistsError) {
return fmt.Errorf("error putting tags in db: %s", err)
}
} }
tags = append(tags, tag.ID) tags = append(tags, tag.ID)
} }

View file

@ -138,7 +138,7 @@ func New(ctx context.Context, db db.DB) (Router, error) {
} }
// set template functions // set template functions
loadTemplateFunctions(engine) LoadTemplateFunctions(engine)
// load templates onto the engine // load templates onto the engine
if err := loadTemplates(engine); err != nil { if err := loadTemplates(engine); err != nil {

View file

@ -33,8 +33,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
) )
// sessionOptions returns the standard set of options to use for each session. // SessionOptions returns the standard set of options to use for each session.
func sessionOptions() sessions.Options { func SessionOptions() sessions.Options {
return sessions.Options{ return sessions.Options{
Path: "/", Path: "/",
Domain: viper.GetString(config.Keys.Host), Domain: viper.GetString(config.Keys.Host),
@ -75,7 +75,7 @@ func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) e
} }
store := memstore.NewStore(rs.Auth, rs.Crypt) store := memstore.NewStore(rs.Auth, rs.Crypt)
store.Options(sessionOptions()) store.Options(SessionOptions())
sessionName, err := SessionName() sessionName, err := SessionName()
if err != nil { if err != nil {

View file

@ -31,7 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
) )
// loadTemplates loads html templates for use by the given engine // LoadTemplates loads html templates for use by the given engine
func loadTemplates(engine *gin.Engine) error { func loadTemplates(engine *gin.Engine) error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
@ -39,8 +39,13 @@ func loadTemplates(engine *gin.Engine) error {
} }
templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir)
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir))
_, err = os.Stat(filepath.Join(cwd, templateBaseDir, "index.tmpl"))
if err != nil {
return fmt.Errorf("%s doesn't seem to contain the templates; index.tmpl is missing: %s", filepath.Join(cwd, templateBaseDir), err)
}
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir))
engine.LoadHTMLGlob(tmPath) engine.LoadHTMLGlob(tmPath)
return nil return nil
} }
@ -87,7 +92,7 @@ func visibilityIcon(visibility model.Visibility) template.HTML {
return template.HTML(fmt.Sprintf(`<i aria-label="Visibility: %v" class="fa fa-%v"></i>`, icon.label, icon.faIcon)) return template.HTML(fmt.Sprintf(`<i aria-label="Visibility: %v" class="fa fa-%v"></i>`, icon.label, icon.faIcon))
} }
func loadTemplateFunctions(engine *gin.Engine) { func LoadTemplateFunctions(engine *gin.Engine) {
engine.SetFuncMap(template.FuncMap{ engine.SetFuncMap(template.FuncMap{
"noescape": noescape, "noescape": noescape,
"oddOrEven": oddOrEven, "oddOrEven": oddOrEven,

View file

@ -20,7 +20,14 @@ package testrig
import ( import (
"context" "context"
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/router" "github.com/superseriousbusiness/gotosocial/internal/router"
) )
@ -33,3 +40,26 @@ func NewTestRouter(db db.DB) router.Router {
} }
return r return r
} }
// ConfigureTemplatesWithGin will panic on any errors related to template loading during tests
func ConfigureTemplatesWithGin(engine *gin.Engine) {
router.LoadTemplateFunctions(engine)
// https://stackoverflow.com/questions/31873396/is-it-possible-to-get-the-current-root-of-package-structure-as-a-string-in-golan
_, runtimeCallerLocation, _, _ := runtime.Caller(0)
projectRoot, err := filepath.Abs(filepath.Join(filepath.Dir(runtimeCallerLocation), "../"))
if err != nil {
panic(err)
}
templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir)
_, err = os.Stat(filepath.Join(projectRoot, templateBaseDir, "index.tmpl"))
if err != nil {
panic(fmt.Errorf("%s doesn't seem to contain the templates; index.tmpl is missing: %s", filepath.Join(projectRoot, templateBaseDir), err))
}
tmPath := filepath.Join(projectRoot, fmt.Sprintf("%s*", templateBaseDir))
engine.LoadHTMLGlob(tmPath)
}

View file

@ -165,6 +165,25 @@ section.login form button {
grid-column: 2; grid-column: 2;
} }
section.error {
display: flex;
flex-direction: row;
align-items: center;
}
section.error span {
font-size: 2em;
}
section.error pre {
border: 1px solid #ff000080;
margin-left: 1em;
padding: 0 0.7em;
border-radius: 0.5em;
background-color: #ff000010;
font-size: 1.3em;
white-space: pre-wrap;
}
input, select, textarea { input, select, textarea {
border: 1px solid #fafaff; border: 1px solid #fafaff;
color: #fafaff; color: #fafaff;

View file

@ -165,6 +165,24 @@ section.login {
} }
} }
section.error {
display: flex;
flex-direction: row;
align-items: center;
span {
font-size: 2em;
}
pre {
border: 1px solid #ff000080;
margin-left: 1em;
padding: 0 0.7em;
border-radius: 0.5em;
background-color: #ff000010;
font-size: 1.3em;
white-space: pre-wrap;
}
}
input, select, textarea { input, select, textarea {
border: 1px solid $fg; border: 1px solid $fg;
color: $fg; color: $fg;

View file

@ -2,7 +2,13 @@
<main> <main>
<form action="/oauth/authorize" method="POST"> <form action="/oauth/authorize" method="POST">
<h1>Hi {{.user}}!</h1> <h1>Hi {{.user}}!</h1>
<p>Application <b>{{.appname}}</b> {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.</p> <p>
Application <b>{{.appname}}</b>
{{if len .appwebsite | eq 0 | not}}
({{.appwebsite}})
{{end}}
would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.
</p>
<p>The application will redirect to {{.redirect}} to continue.</p> <p>The application will redirect to {{.redirect}} to continue.</p>
<p> <p>
<button <button

8
web/template/error.tmpl Normal file
View file

@ -0,0 +1,8 @@
{{ template "header.tmpl" .}}
<main>
<section class="error">
<span>❌</span> <pre>{{.error}}</pre>
</section>
</main>
{{ template "footer.tmpl" .}}