forked from mirrors/gotosocial
ec325fee14
* [chore] Update a bunch of database dependencies * fix lil thing
100 lines
2.3 KiB
Go
100 lines
2.3 KiB
Go
package pgconn
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
)
|
|
|
|
// NewGSSFunc creates a GSS authentication provider, for use with
|
|
// RegisterGSSProvider.
|
|
type NewGSSFunc func() (GSS, error)
|
|
|
|
var newGSS NewGSSFunc
|
|
|
|
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
|
// you need to use Kerberos to authenticate with your server, add this to your
|
|
// main package:
|
|
//
|
|
// import "github.com/otan/gopgkrb5"
|
|
//
|
|
// func init() {
|
|
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
|
// }
|
|
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
|
newGSS = newGSSArg
|
|
}
|
|
|
|
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
|
type GSS interface {
|
|
GetInitToken(host string, service string) ([]byte, error)
|
|
GetInitTokenFromSPN(spn string) ([]byte, error)
|
|
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
|
}
|
|
|
|
func (c *PgConn) gssAuth() error {
|
|
if newGSS == nil {
|
|
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
|
}
|
|
cli, err := newGSS()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var nextData []byte
|
|
if c.config.KerberosSpn != "" {
|
|
// Use the supplied SPN if provided.
|
|
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
|
} else {
|
|
// Allow the kerberos service name to be overridden
|
|
service := "postgres"
|
|
if c.config.KerberosSrvName != "" {
|
|
service = c.config.KerberosSrvName
|
|
}
|
|
nextData, err = cli.GetInitToken(c.config.Host, service)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for {
|
|
gssResponse := &pgproto3.GSSResponse{
|
|
Data: nextData,
|
|
}
|
|
c.frontend.Send(gssResponse)
|
|
err = c.frontend.Flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := c.rxGSSContinue()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var done bool
|
|
done, nextData, err = cli.Continue(resp.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
|
msg, err := c.receiveMessage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch m := msg.(type) {
|
|
case *pgproto3.AuthenticationGSSContinue:
|
|
return m, nil
|
|
case *pgproto3.ErrorResponse:
|
|
return nil, ErrorResponseToPgError(m)
|
|
}
|
|
|
|
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
|
}
|