diff --git a/server/api/hook.go b/server/api/hook.go index 40a93ef35..0ed75a4f3 100644 --- a/server/api/hook.go +++ b/server/api/hook.go @@ -178,7 +178,7 @@ func PostHook(c *gin.Context) { // // get the token and verify the hook is authorized - parsedToken, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) { + parsedToken, err := token.ParseRequest([]token.Type{token.HookToken}, c.Request, func(_ *token.Token) (string, error) { return repo.Hash, nil }) if err != nil { diff --git a/server/router/middleware/session/agent.go b/server/router/middleware/session/agent.go index cacaf0d44..0cbfa2a33 100644 --- a/server/router/middleware/session/agent.go +++ b/server/router/middleware/session/agent.go @@ -31,17 +31,14 @@ func AuthorizeAgent(c *gin.Context) { return } - parsed, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) { + _, err := token.ParseRequest([]token.Type{token.AgentToken}, c.Request, func(_ *token.Token) (string, error) { return secret, nil }) - switch { - case err != nil: + if err != nil { c.String(http.StatusInternalServerError, "invalid or empty token. %s", err) c.Abort() - case parsed.Kind != token.AgentToken: - c.String(http.StatusForbidden, "invalid token. please use an agent token") - c.Abort() - default: - c.Next() + return } + + c.Next() } diff --git a/server/router/middleware/session/user.go b/server/router/middleware/session/user.go index b413cf528..2f098b5cb 100644 --- a/server/router/middleware/session/user.go +++ b/server/router/middleware/session/user.go @@ -43,7 +43,7 @@ func SetUser() gin.HandlerFunc { return func(c *gin.Context) { var user *model.User - t, err := token.ParseRequest(c.Request, func(t *token.Token) (string, error) { + t, err := token.ParseRequest([]token.Type{token.UserToken, token.SessToken}, c.Request, func(t *token.Token) (string, error) { var err error userID, err := strconv.ParseInt(t.Get("user-id"), 10, 64) if err != nil { @@ -58,7 +58,7 @@ func SetUser() gin.HandlerFunc { // if this is a session token (ie not the API token) // this means the user is accessing with a web browser, // so we should implement CSRF protection measures. - if t.Kind == token.SessToken { + if t.Type == token.SessToken { err = token.CheckCsrf(c.Request, func(_ *token.Token) (string, error) { return user.Hash, nil }) diff --git a/shared/token/token.go b/shared/token/token.go index 545107f59..f893babfe 100644 --- a/shared/token/token.go +++ b/shared/token/token.go @@ -24,36 +24,52 @@ import ( type SecretFunc func(*Token) (string, error) +type Type string + const ( - UserToken = "user" - SessToken = "sess" - HookToken = "hook" - CsrfToken = "csrf" - AgentToken = "agent" + UserToken Type = "user" // user token (exp cli) + SessToken Type = "sess" // session token (ui token requires csrf check) + HookToken Type = "hook" // repo hook token + CsrfToken Type = "csrf" + AgentToken Type = "agent" ) // SignerAlgo id default algorithm used to sign JWT tokens. const SignerAlgo = "HS256" type Token struct { - Kind string + Type Type claims jwt.MapClaims } -func parse(raw string, fn SecretFunc) (*Token, error) { +func Parse(allowedTypes []Type, raw string, fn SecretFunc) (*Token, error) { token := &Token{ claims: jwt.MapClaims{}, } parsed, err := jwt.Parse(raw, keyFunc(token, fn)) if err != nil { return nil, err - } else if !parsed.Valid { + } + if !parsed.Valid { return nil, jwt.ErrTokenUnverifiable } + + hasAllowedType := false + for _, k := range allowedTypes { + if k == token.Type { + hasAllowedType = true + break + } + } + + if !hasAllowedType { + return nil, jwt.ErrInvalidType + } + return token, nil } -func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) { +func ParseRequest(allowedTypes []Type, r *http.Request, fn SecretFunc) (*Token, error) { // first we attempt to get the token from the // authorization header. token := r.Header.Get("Authorization") @@ -63,19 +79,19 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) { if _, err := fmt.Sscanf(token, "Bearer %s", &bearer); err != nil { return nil, err } - return parse(bearer, fn) + return Parse(allowedTypes, bearer, fn) } token = r.Header.Get("X-Gitlab-Token") if len(token) != 0 { - return parse(token, fn) + return Parse(allowedTypes, token, fn) } // then we attempt to get the token from the // access_token url query parameter token = r.FormValue("access_token") if len(token) != 0 { - return parse(token, fn) + return Parse(allowedTypes, token, fn) } // and finally we attempt to get the token from @@ -84,7 +100,7 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) { if err != nil { return nil, err } - return parse(cookie.Value, fn) + return Parse(allowedTypes, cookie.Value, fn) } func CheckCsrf(r *http.Request, fn SecretFunc) error { @@ -97,12 +113,12 @@ func CheckCsrf(r *http.Request, fn SecretFunc) error { // parse the raw CSRF token value and validate raw := r.Header.Get("X-CSRF-TOKEN") - _, err := parse(raw, fn) + _, err := Parse([]Type{CsrfToken}, raw, fn) return err } -func New(kind string) *Token { - return &Token{Kind: kind, claims: jwt.MapClaims{}} +func New(tokenType Type) *Token { + return &Token{Type: tokenType, claims: jwt.MapClaims{}} } // Sign signs the token using the given secret hash @@ -124,7 +140,7 @@ func (t *Token) SignExpires(secret string, exp int64) (string, error) { claims[k] = v } - claims["type"] = t.Kind + claims["type"] = t.Type if exp > 0 { claims["exp"] = float64(exp) } @@ -157,12 +173,12 @@ func keyFunc(token *Token, fn SecretFunc) jwt.Keyfunc { return nil, jwt.ErrSignatureInvalid } - // extract the token kind and cast to the expected type - kind, ok := claims["type"] + // extract the token type and cast to the expected type + tokenType, ok := claims["type"].(string) if !ok { return nil, jwt.ErrInvalidType } - token.Kind, _ = kind.(string) + token.Type = Type(tokenType) // copy custom claims for k, v := range claims { diff --git a/shared/token/token_test.go b/shared/token/token_test.go new file mode 100644 index 000000000..fb77ef609 --- /dev/null +++ b/shared/token/token_test.go @@ -0,0 +1,62 @@ +package token_test + +import ( + "testing" + + "github.com/franela/goblin" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "go.woodpecker-ci.org/woodpecker/v2/shared/token" +) + +func TestToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + g := goblin.Goblin(t) + g.Describe("Token", func() { + jwtSecret := "secret-to-sign-the-token" + + g.It("should parse a valid token", func() { + _token := token.New(token.UserToken) + _token.Set("user-id", "1") + signedToken, err := _token.Sign(jwtSecret) + assert.NoError(g, err) + + parsed, err := token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) { + return jwtSecret, nil + }) + + assert.NoError(g, err) + assert.NotNil(g, parsed) + assert.Equal(g, "1", parsed.Get("user-id")) + }) + + g.It("should fail to parse a token with a wrong type", func() { + _token := token.New(token.UserToken) + _token.Set("user-id", "1") + signedToken, err := _token.Sign(jwtSecret) + assert.NoError(g, err) + + _, err = token.Parse([]token.Type{token.AgentToken}, signedToken, func(_ *token.Token) (string, error) { + return jwtSecret, nil + }) + + assert.ErrorIs(g, err, jwt.ErrInvalidType) + }) + + g.It("should fail to parse a token with a wrong secret", func() { + _token := token.New(token.UserToken) + _token.Set("user-id", "1") + signedToken, err := _token.Sign(jwtSecret) + assert.NoError(g, err) + + _, err = token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) { + return "this-is-a-wrong-secret", nil + }) + + assert.ErrorIs(g, err, jwt.ErrSignatureInvalid) + }) + }) +}