forked from mirrors/gotosocial
fce3ba6382
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
504 lines
14 KiB
Go
504 lines
14 KiB
Go
package manage
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/superseriousbusiness/oauth2/v4"
|
|
"github.com/superseriousbusiness/oauth2/v4/errors"
|
|
"github.com/superseriousbusiness/oauth2/v4/generates"
|
|
"github.com/superseriousbusiness/oauth2/v4/models"
|
|
)
|
|
|
|
// NewDefaultManager create to default authorization management instance
|
|
func NewDefaultManager() *Manager {
|
|
m := NewManager()
|
|
// default implementation
|
|
m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
|
|
m.MapAccessGenerate(generates.NewAccessGenerate())
|
|
|
|
return m
|
|
}
|
|
|
|
// NewManager create to authorization management instance
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
gtcfg: make(map[oauth2.GrantType]*Config),
|
|
validateURI: DefaultValidateURI,
|
|
}
|
|
}
|
|
|
|
// Manager provide authorization management
|
|
type Manager struct {
|
|
codeExp time.Duration
|
|
gtcfg map[oauth2.GrantType]*Config
|
|
rcfg *RefreshingConfig
|
|
validateURI ValidateURIHandler
|
|
authorizeGenerate oauth2.AuthorizeGenerate
|
|
accessGenerate oauth2.AccessGenerate
|
|
tokenStore oauth2.TokenStore
|
|
clientStore oauth2.ClientStore
|
|
}
|
|
|
|
// get grant type config
|
|
func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
|
|
if c, ok := m.gtcfg[gt]; ok && c != nil {
|
|
return c
|
|
}
|
|
switch gt {
|
|
case oauth2.AuthorizationCode:
|
|
return DefaultAuthorizeCodeTokenCfg
|
|
case oauth2.Implicit:
|
|
return DefaultImplicitTokenCfg
|
|
case oauth2.PasswordCredentials:
|
|
return DefaultPasswordTokenCfg
|
|
case oauth2.ClientCredentials:
|
|
return DefaultClientTokenCfg
|
|
}
|
|
return &Config{}
|
|
}
|
|
|
|
// SetAuthorizeCodeExp set the authorization code expiration time
|
|
func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
|
|
m.codeExp = exp
|
|
}
|
|
|
|
// SetAuthorizeCodeTokenCfg set the authorization code grant token config
|
|
func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
|
|
m.gtcfg[oauth2.AuthorizationCode] = cfg
|
|
}
|
|
|
|
// SetImplicitTokenCfg set the implicit grant token config
|
|
func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
|
|
m.gtcfg[oauth2.Implicit] = cfg
|
|
}
|
|
|
|
// SetPasswordTokenCfg set the password grant token config
|
|
func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
|
|
m.gtcfg[oauth2.PasswordCredentials] = cfg
|
|
}
|
|
|
|
// SetClientTokenCfg set the client grant token config
|
|
func (m *Manager) SetClientTokenCfg(cfg *Config) {
|
|
m.gtcfg[oauth2.ClientCredentials] = cfg
|
|
}
|
|
|
|
// SetRefreshTokenCfg set the refreshing token config
|
|
func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
|
|
m.rcfg = cfg
|
|
}
|
|
|
|
// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
|
|
func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
|
|
m.validateURI = handler
|
|
}
|
|
|
|
// MapAuthorizeGenerate mapping the authorize code generate interface
|
|
func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
|
|
m.authorizeGenerate = gen
|
|
}
|
|
|
|
// MapAccessGenerate mapping the access token generate interface
|
|
func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
|
|
m.accessGenerate = gen
|
|
}
|
|
|
|
// MapClientStorage mapping the client store interface
|
|
func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
|
|
m.clientStore = stor
|
|
}
|
|
|
|
// MustClientStorage mandatory mapping the client store interface
|
|
func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
|
|
if err != nil {
|
|
panic(err.Error())
|
|
}
|
|
m.clientStore = stor
|
|
}
|
|
|
|
// MapTokenStorage mapping the token store interface
|
|
func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
|
|
m.tokenStore = stor
|
|
}
|
|
|
|
// MustTokenStorage mandatory mapping the token store interface
|
|
func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
m.tokenStore = stor
|
|
}
|
|
|
|
// GetClient get the client information
|
|
func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
|
|
cli, err = m.clientStore.GetByID(ctx, clientID)
|
|
if err != nil {
|
|
return
|
|
} else if cli == nil {
|
|
err = errors.ErrInvalidClient
|
|
}
|
|
return
|
|
}
|
|
|
|
// GenerateAuthToken generate the authorization token(code)
|
|
func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
|
cli, err := m.GetClient(ctx, tgr.ClientID)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if tgr.RedirectURI != "" {
|
|
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
ti := models.NewToken()
|
|
ti.SetClientID(tgr.ClientID)
|
|
ti.SetUserID(tgr.UserID)
|
|
ti.SetRedirectURI(tgr.RedirectURI)
|
|
ti.SetScope(tgr.Scope)
|
|
|
|
createAt := time.Now()
|
|
td := &oauth2.GenerateBasic{
|
|
Client: cli,
|
|
UserID: tgr.UserID,
|
|
CreateAt: createAt,
|
|
TokenInfo: ti,
|
|
Request: tgr.Request,
|
|
}
|
|
switch rt {
|
|
case oauth2.Code:
|
|
codeExp := m.codeExp
|
|
if codeExp == 0 {
|
|
codeExp = DefaultCodeExp
|
|
}
|
|
ti.SetCodeCreateAt(createAt)
|
|
ti.SetCodeExpiresIn(codeExp)
|
|
if exp := tgr.AccessTokenExp; exp > 0 {
|
|
ti.SetAccessExpiresIn(exp)
|
|
}
|
|
if tgr.CodeChallenge != "" {
|
|
ti.SetCodeChallenge(tgr.CodeChallenge)
|
|
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
|
|
}
|
|
|
|
tv, err := m.authorizeGenerate.Token(ctx, td)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ti.SetCode(tv)
|
|
case oauth2.Token:
|
|
// set access token expires
|
|
icfg := m.grantConfig(oauth2.Implicit)
|
|
aexp := icfg.AccessTokenExp
|
|
if exp := tgr.AccessTokenExp; exp > 0 {
|
|
aexp = exp
|
|
}
|
|
ti.SetAccessCreateAt(createAt)
|
|
ti.SetAccessExpiresIn(aexp)
|
|
|
|
if icfg.IsGenerateRefresh {
|
|
ti.SetRefreshCreateAt(createAt)
|
|
ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
|
|
}
|
|
|
|
tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ti.SetAccess(tv)
|
|
|
|
if rv != "" {
|
|
ti.SetRefresh(rv)
|
|
}
|
|
}
|
|
|
|
err = m.tokenStore.Create(ctx, ti)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ti, nil
|
|
}
|
|
|
|
// get authorization code data
|
|
func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
|
ti, err := m.tokenStore.GetByCode(ctx, code)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
|
|
err = errors.ErrInvalidAuthorizeCode
|
|
return nil, errors.ErrInvalidAuthorizeCode
|
|
}
|
|
return ti, nil
|
|
}
|
|
|
|
// delete authorization code data
|
|
func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
|
|
return m.tokenStore.RemoveByCode(ctx, code)
|
|
}
|
|
|
|
// get and delete authorization code data
|
|
func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
|
code := tgr.Code
|
|
ti, err := m.getAuthorizationCode(ctx, code)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if ti.GetClientID() != tgr.ClientID {
|
|
return nil, errors.ErrInvalidAuthorizeCode
|
|
} else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
|
|
return nil, errors.ErrInvalidAuthorizeCode
|
|
}
|
|
|
|
err = m.delAuthorizationCode(ctx, code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ti, nil
|
|
}
|
|
|
|
func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
|
|
cc := ti.GetCodeChallenge()
|
|
// early return
|
|
if cc == "" && ver == "" {
|
|
return nil
|
|
}
|
|
if cc == "" {
|
|
return errors.ErrMissingCodeVerifier
|
|
}
|
|
if ver == "" {
|
|
return errors.ErrMissingCodeVerifier
|
|
}
|
|
ccm := ti.GetCodeChallengeMethod()
|
|
if ccm.String() == "" {
|
|
ccm = oauth2.CodeChallengePlain
|
|
}
|
|
if !ccm.Validate(cc, ver) {
|
|
return errors.ErrInvalidCodeChallenge
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GenerateAccessToken generate the access token
|
|
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
|
cli, err := m.GetClient(ctx, tgr.ClientID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
|
|
if !cliPass.VerifyPassword(tgr.ClientSecret) {
|
|
return nil, errors.ErrInvalidClient
|
|
}
|
|
} else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
|
|
return nil, errors.ErrInvalidClient
|
|
}
|
|
if tgr.RedirectURI != "" {
|
|
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if gt == oauth2.AuthorizationCode {
|
|
ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
|
|
return nil, err
|
|
}
|
|
tgr.UserID = ti.GetUserID()
|
|
tgr.Scope = ti.GetScope()
|
|
if exp := ti.GetAccessExpiresIn(); exp > 0 {
|
|
tgr.AccessTokenExp = exp
|
|
}
|
|
}
|
|
|
|
ti := models.NewToken()
|
|
ti.SetClientID(tgr.ClientID)
|
|
ti.SetUserID(tgr.UserID)
|
|
ti.SetRedirectURI(tgr.RedirectURI)
|
|
ti.SetScope(tgr.Scope)
|
|
|
|
createAt := time.Now()
|
|
ti.SetAccessCreateAt(createAt)
|
|
|
|
// set access token expires
|
|
gcfg := m.grantConfig(gt)
|
|
aexp := gcfg.AccessTokenExp
|
|
if exp := tgr.AccessTokenExp; exp > 0 {
|
|
aexp = exp
|
|
}
|
|
ti.SetAccessExpiresIn(aexp)
|
|
if gcfg.IsGenerateRefresh {
|
|
ti.SetRefreshCreateAt(createAt)
|
|
ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
|
|
}
|
|
|
|
td := &oauth2.GenerateBasic{
|
|
Client: cli,
|
|
UserID: tgr.UserID,
|
|
CreateAt: createAt,
|
|
TokenInfo: ti,
|
|
Request: tgr.Request,
|
|
}
|
|
|
|
av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ti.SetAccess(av)
|
|
|
|
if rv != "" {
|
|
ti.SetRefresh(rv)
|
|
}
|
|
|
|
err = m.tokenStore.Create(ctx, ti)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ti, nil
|
|
}
|
|
|
|
// RefreshAccessToken refreshing an access token
|
|
func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
|
cli, err := m.GetClient(ctx, tgr.ClientID)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
|
|
if !cliPass.VerifyPassword(tgr.ClientSecret) {
|
|
return nil, errors.ErrInvalidClient
|
|
}
|
|
} else if tgr.ClientSecret != cli.GetSecret() {
|
|
return nil, errors.ErrInvalidClient
|
|
}
|
|
|
|
ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if ti.GetClientID() != tgr.ClientID {
|
|
return nil, errors.ErrInvalidRefreshToken
|
|
}
|
|
|
|
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
|
|
|
|
td := &oauth2.GenerateBasic{
|
|
Client: cli,
|
|
UserID: ti.GetUserID(),
|
|
CreateAt: time.Now(),
|
|
TokenInfo: ti,
|
|
Request: tgr.Request,
|
|
}
|
|
|
|
rcfg := DefaultRefreshTokenCfg
|
|
if v := m.rcfg; v != nil {
|
|
rcfg = v
|
|
}
|
|
|
|
ti.SetAccessCreateAt(td.CreateAt)
|
|
if v := rcfg.AccessTokenExp; v > 0 {
|
|
ti.SetAccessExpiresIn(v)
|
|
}
|
|
|
|
if v := rcfg.RefreshTokenExp; v > 0 {
|
|
ti.SetRefreshExpiresIn(v)
|
|
}
|
|
|
|
if rcfg.IsResetRefreshTime {
|
|
ti.SetRefreshCreateAt(td.CreateAt)
|
|
}
|
|
|
|
if scope := tgr.Scope; scope != "" {
|
|
ti.SetScope(scope)
|
|
}
|
|
|
|
tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ti.SetAccess(tv)
|
|
if rv != "" {
|
|
ti.SetRefresh(rv)
|
|
}
|
|
|
|
if err := m.tokenStore.Create(ctx, ti); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if rcfg.IsRemoveAccess {
|
|
// remove the old access token
|
|
if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if rcfg.IsRemoveRefreshing && rv != "" {
|
|
// remove the old refresh token
|
|
if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if rv == "" {
|
|
ti.SetRefresh("")
|
|
ti.SetRefreshCreateAt(time.Now())
|
|
ti.SetRefreshExpiresIn(0)
|
|
}
|
|
|
|
return ti, nil
|
|
}
|
|
|
|
// RemoveAccessToken use the access token to delete the token information
|
|
func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
|
|
if access == "" {
|
|
return errors.ErrInvalidAccessToken
|
|
}
|
|
return m.tokenStore.RemoveByAccess(ctx, access)
|
|
}
|
|
|
|
// RemoveRefreshToken use the refresh token to delete the token information
|
|
func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
|
|
if refresh == "" {
|
|
return errors.ErrInvalidAccessToken
|
|
}
|
|
return m.tokenStore.RemoveByRefresh(ctx, refresh)
|
|
}
|
|
|
|
// LoadAccessToken according to the access token for corresponding token information
|
|
func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
|
if access == "" {
|
|
return nil, errors.ErrInvalidAccessToken
|
|
}
|
|
|
|
ct := time.Now()
|
|
ti, err := m.tokenStore.GetByAccess(ctx, access)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if ti == nil || ti.GetAccess() != access {
|
|
return nil, errors.ErrInvalidAccessToken
|
|
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
|
|
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
|
|
return nil, errors.ErrExpiredRefreshToken
|
|
} else if ti.GetAccessExpiresIn() != 0 &&
|
|
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
|
|
return nil, errors.ErrExpiredAccessToken
|
|
}
|
|
return ti, nil
|
|
}
|
|
|
|
// LoadRefreshToken according to the refresh token for corresponding token information
|
|
func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
|
if refresh == "" {
|
|
return nil, errors.ErrInvalidRefreshToken
|
|
}
|
|
|
|
ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if ti == nil || ti.GetRefresh() != refresh {
|
|
return nil, errors.ErrInvalidRefreshToken
|
|
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
|
|
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
|
|
return nil, errors.ErrExpiredRefreshToken
|
|
}
|
|
return ti, nil
|
|
}
|