Enhance token checking (#3842)

This commit is contained in:
Anbraten 2024-06-27 00:08:59 +02:00 committed by GitHub
parent ea8976bf88
commit b8b6efb352
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 106 additions and 31 deletions

View file

@ -178,7 +178,7 @@ func PostHook(c *gin.Context) {
// //
// get the token and verify the hook is authorized // get the token and verify the hook is authorized
parsedToken, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) { parsedToken, err := token.ParseRequest([]token.Type{token.HookToken}, c.Request, func(_ *token.Token) (string, error) {
return repo.Hash, nil return repo.Hash, nil
}) })
if err != nil { if err != nil {

View file

@ -31,17 +31,14 @@ func AuthorizeAgent(c *gin.Context) {
return return
} }
parsed, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) { _, err := token.ParseRequest([]token.Type{token.AgentToken}, c.Request, func(_ *token.Token) (string, error) {
return secret, nil return secret, nil
}) })
switch { if err != nil {
case err != nil:
c.String(http.StatusInternalServerError, "invalid or empty token. %s", err) c.String(http.StatusInternalServerError, "invalid or empty token. %s", err)
c.Abort() c.Abort()
case parsed.Kind != token.AgentToken: return
c.String(http.StatusForbidden, "invalid token. please use an agent token")
c.Abort()
default:
c.Next()
} }
c.Next()
} }

View file

@ -43,7 +43,7 @@ func SetUser() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var user *model.User var user *model.User
t, err := token.ParseRequest(c.Request, func(t *token.Token) (string, error) { t, err := token.ParseRequest([]token.Type{token.UserToken, token.SessToken}, c.Request, func(t *token.Token) (string, error) {
var err error var err error
userID, err := strconv.ParseInt(t.Get("user-id"), 10, 64) userID, err := strconv.ParseInt(t.Get("user-id"), 10, 64)
if err != nil { if err != nil {
@ -58,7 +58,7 @@ func SetUser() gin.HandlerFunc {
// if this is a session token (ie not the API token) // if this is a session token (ie not the API token)
// this means the user is accessing with a web browser, // this means the user is accessing with a web browser,
// so we should implement CSRF protection measures. // so we should implement CSRF protection measures.
if t.Kind == token.SessToken { if t.Type == token.SessToken {
err = token.CheckCsrf(c.Request, func(_ *token.Token) (string, error) { err = token.CheckCsrf(c.Request, func(_ *token.Token) (string, error) {
return user.Hash, nil return user.Hash, nil
}) })

View file

@ -24,36 +24,52 @@ import (
type SecretFunc func(*Token) (string, error) type SecretFunc func(*Token) (string, error)
type Type string
const ( const (
UserToken = "user" UserToken Type = "user" // user token (exp cli)
SessToken = "sess" SessToken Type = "sess" // session token (ui token requires csrf check)
HookToken = "hook" HookToken Type = "hook" // repo hook token
CsrfToken = "csrf" CsrfToken Type = "csrf"
AgentToken = "agent" AgentToken Type = "agent"
) )
// SignerAlgo id default algorithm used to sign JWT tokens. // SignerAlgo id default algorithm used to sign JWT tokens.
const SignerAlgo = "HS256" const SignerAlgo = "HS256"
type Token struct { type Token struct {
Kind string Type Type
claims jwt.MapClaims claims jwt.MapClaims
} }
func parse(raw string, fn SecretFunc) (*Token, error) { func Parse(allowedTypes []Type, raw string, fn SecretFunc) (*Token, error) {
token := &Token{ token := &Token{
claims: jwt.MapClaims{}, claims: jwt.MapClaims{},
} }
parsed, err := jwt.Parse(raw, keyFunc(token, fn)) parsed, err := jwt.Parse(raw, keyFunc(token, fn))
if err != nil { if err != nil {
return nil, err return nil, err
} else if !parsed.Valid { }
if !parsed.Valid {
return nil, jwt.ErrTokenUnverifiable return nil, jwt.ErrTokenUnverifiable
} }
hasAllowedType := false
for _, k := range allowedTypes {
if k == token.Type {
hasAllowedType = true
break
}
}
if !hasAllowedType {
return nil, jwt.ErrInvalidType
}
return token, nil return token, nil
} }
func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) { func ParseRequest(allowedTypes []Type, r *http.Request, fn SecretFunc) (*Token, error) {
// first we attempt to get the token from the // first we attempt to get the token from the
// authorization header. // authorization header.
token := r.Header.Get("Authorization") token := r.Header.Get("Authorization")
@ -63,19 +79,19 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if _, err := fmt.Sscanf(token, "Bearer %s", &bearer); err != nil { if _, err := fmt.Sscanf(token, "Bearer %s", &bearer); err != nil {
return nil, err return nil, err
} }
return parse(bearer, fn) return Parse(allowedTypes, bearer, fn)
} }
token = r.Header.Get("X-Gitlab-Token") token = r.Header.Get("X-Gitlab-Token")
if len(token) != 0 { if len(token) != 0 {
return parse(token, fn) return Parse(allowedTypes, token, fn)
} }
// then we attempt to get the token from the // then we attempt to get the token from the
// access_token url query parameter // access_token url query parameter
token = r.FormValue("access_token") token = r.FormValue("access_token")
if len(token) != 0 { if len(token) != 0 {
return parse(token, fn) return Parse(allowedTypes, token, fn)
} }
// and finally we attempt to get the token from // and finally we attempt to get the token from
@ -84,7 +100,7 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return parse(cookie.Value, fn) return Parse(allowedTypes, cookie.Value, fn)
} }
func CheckCsrf(r *http.Request, fn SecretFunc) error { func CheckCsrf(r *http.Request, fn SecretFunc) error {
@ -97,12 +113,12 @@ func CheckCsrf(r *http.Request, fn SecretFunc) error {
// parse the raw CSRF token value and validate // parse the raw CSRF token value and validate
raw := r.Header.Get("X-CSRF-TOKEN") raw := r.Header.Get("X-CSRF-TOKEN")
_, err := parse(raw, fn) _, err := Parse([]Type{CsrfToken}, raw, fn)
return err return err
} }
func New(kind string) *Token { func New(tokenType Type) *Token {
return &Token{Kind: kind, claims: jwt.MapClaims{}} return &Token{Type: tokenType, claims: jwt.MapClaims{}}
} }
// Sign signs the token using the given secret hash // Sign signs the token using the given secret hash
@ -124,7 +140,7 @@ func (t *Token) SignExpires(secret string, exp int64) (string, error) {
claims[k] = v claims[k] = v
} }
claims["type"] = t.Kind claims["type"] = t.Type
if exp > 0 { if exp > 0 {
claims["exp"] = float64(exp) claims["exp"] = float64(exp)
} }
@ -157,12 +173,12 @@ func keyFunc(token *Token, fn SecretFunc) jwt.Keyfunc {
return nil, jwt.ErrSignatureInvalid return nil, jwt.ErrSignatureInvalid
} }
// extract the token kind and cast to the expected type // extract the token type and cast to the expected type
kind, ok := claims["type"] tokenType, ok := claims["type"].(string)
if !ok { if !ok {
return nil, jwt.ErrInvalidType return nil, jwt.ErrInvalidType
} }
token.Kind, _ = kind.(string) token.Type = Type(tokenType)
// copy custom claims // copy custom claims
for k, v := range claims { for k, v := range claims {

View file

@ -0,0 +1,62 @@
package token_test
import (
"testing"
"github.com/franela/goblin"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"go.woodpecker-ci.org/woodpecker/v2/shared/token"
)
func TestToken(t *testing.T) {
gin.SetMode(gin.TestMode)
g := goblin.Goblin(t)
g.Describe("Token", func() {
jwtSecret := "secret-to-sign-the-token"
g.It("should parse a valid token", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)
parsed, err := token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})
assert.NoError(g, err)
assert.NotNil(g, parsed)
assert.Equal(g, "1", parsed.Get("user-id"))
})
g.It("should fail to parse a token with a wrong type", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)
_, err = token.Parse([]token.Type{token.AgentToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})
assert.ErrorIs(g, err, jwt.ErrInvalidType)
})
g.It("should fail to parse a token with a wrong secret", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)
_, err = token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return "this-is-a-wrong-secret", nil
})
assert.ErrorIs(g, err, jwt.ErrSignatureInvalid)
})
})
}