gotosocial/vendor/mellium.im/sasl/scram.go
Tobi Smethurst 98263a7de6
Grand test fixup (#138)
* start fixing up tests

* fix up tests + automate with drone

* fiddle with linting

* messing about with drone.yml

* some more fiddling

* hmmm

* add cache

* add vendor directory

* verbose

* ci updates

* update some little things

* update sig
2021-08-12 21:03:24 +02:00

227 lines
6.6 KiB
Go

// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
package sasl
import (
"bytes"
"crypto/hmac"
"encoding/base64"
"errors"
"hash"
"strconv"
"strings"
"golang.org/x/crypto/pbkdf2"
)
// GS2 header prefixes advertising the client's channel-binding situation.
// getGS2Header picks one of them for every SCRAM exchange.
const (
	// gs2HeaderNoCBSupport is sent when the client does not use channel binding.
	gs2HeaderNoCBSupport = "n,"
	// gs2HeaderNoServerCBSupport is sent when the client supports channel
	// binding but the server does not appear to.
	gs2HeaderNoServerCBSupport = "y,"
	// gs2HeaderCBSupport is sent when the client binds to the TLS channel using
	// the tls-unique channel binding type.
	gs2HeaderCBSupport = "p=tls-unique,"

	// noncerandlen is the number of random bytes to generate for a nonce.
	noncerandlen = 16
)

// HMAC inputs used to derive the ServerKey and ClientKey from the salted
// password.
var (
	serverKeyInput = []byte("Server Key")
	clientKeyInput = []byte("Client Key")
)
// getGS2Header builds the GS2 header that prefixes the client-first-message
// and is later echoed back, base64-encoded, in the c= attribute of the
// client-final-message.
//
// The channel-binding flag is chosen as follows:
//   - "n," when there is no TLS state or the mechanism name does not end in
//     "-PLUS" (we do not use channel binding);
//   - "p=tls-unique," when we bind and the negotiator state has RemoteCB set
//     (the server supports channel binding too);
//   - "y," otherwise (we support channel binding but the server does not).
//
// If the negotiator's credentials carry an authorization identity, it is
// appended as "a=<identity>". Because gs2-authzid is a saslname, any "=" and
// "," in the identity are escaped as "=3D" and "=2C". The returned header
// always ends with the "," that separates it from the message that follows.
func getGS2Header(name string, n *Negotiator) (gs2Header []byte) {
	_, _, identity := n.Credentials()
	switch {
	case n.TLSState() == nil || !strings.HasSuffix(name, "-PLUS"):
		// We do not support channel binding.
		gs2Header = []byte(gs2HeaderNoCBSupport)
	case n.State()&RemoteCB == RemoteCB:
		// We support channel binding and the server does too.
		gs2Header = []byte(gs2HeaderCBSupport)
	default:
		// We support channel binding but the server does not.
		gs2Header = []byte(gs2HeaderNoServerCBSupport)
	}
	if len(identity) > 0 {
		// Escape "=" first so the "=" introduced by "=2C" is not re-escaped.
		escaped := bytes.ReplaceAll(identity, []byte("="), []byte("=3D"))
		escaped = bytes.ReplaceAll(escaped, []byte(","), []byte("=2C"))
		gs2Header = append(gs2Header, "a="...)
		gs2Header = append(gs2Header, escaped...)
	}
	return append(gs2Header, ',')
}
// scram returns a SCRAM Mechanism with the given name, using fn to construct
// the hash used for PBKDF2, HMAC and the stored-key digest.
//
// Start produces the client-first-message: the GS2 header followed by
// "n=<escaped username>,r=<client nonce>". It caches the
// client-first-message-bare (everything after the GS2 header) for the next
// step.
//
// Next rejects an empty challenge with ErrInvalidChallenge and delegates the
// client side to scramClientNext. The Receiving side (presumably the server
// role) is not implemented, and Next returns an error for it.
func scram(name string, fn func() hash.Hash) Mechanism {
	// BUG(ssw): We need a way to cache the SCRAM client and server key
	// calculations.
	return Mechanism{
		Name: name,
		Start: func(m *Negotiator) (bool, []byte, interface{}, error) {
			user, _, _ := m.Credentials()
			nonce := m.Nonce()

			// Escape "=" and "," as "=3D" and "=2C". This is mostly the same as
			// bytes.Replace but faster, because both replacements happen in a
			// single pass. Each escaped byte grows by two, hence len(user)+2n.
			n := bytes.Count(user, []byte{'='}) + bytes.Count(user, []byte{','})
			username := make([]byte, len(user)+(n*2))
			w := 0
			start := 0
			for i := 0; i < n; i++ {
				j := start + bytes.IndexAny(user[start:], "=,")
				w += copy(username[w:], user[start:j])
				switch user[j] {
				case '=':
					w += copy(username[w:], "=3D")
				case ',':
					w += copy(username[w:], "=2C")
				}
				start = j + 1
			}
			copy(username[w:], user[start:])

			// client-first-message-bare: "n=" username ",r=" nonce
			clientFirstMessage := make([]byte, 5+len(nonce)+len(username))
			copy(clientFirstMessage, "n=")
			copy(clientFirstMessage[2:], username)
			copy(clientFirstMessage[2+len(username):], ",r=")
			copy(clientFirstMessage[5+len(username):], nonce)
			return true, append(getGS2Header(name, m), clientFirstMessage...), clientFirstMessage, nil
		},
		Next: func(m *Negotiator, challenge []byte, data interface{}) (more bool, resp []byte, cache interface{}, err error) {
			if len(challenge) == 0 {
				return more, resp, cache, ErrInvalidChallenge
			}
			if m.State()&Receiving == Receiving {
				// Return an error rather than panicking, so a caller that reaches
				// this unimplemented path fails the negotiation instead of
				// crashing.
				return more, resp, cache, errors.New("sasl: SCRAM receiving side is not implemented")
			}
			return scramClientNext(name, fn, m, challenge, data)
		},
	}
}
// scramClientNext performs the client side of a SCRAM exchange after the
// client-first-message has been sent.
//
// In the AuthTextSent step, challenge is the server-first-message and data
// must be the client-first-message-bare cached by Start. The function:
//   - parses the r (nonce), s (base64 salt) and i (iteration count)
//     attributes;
//   - rejects the reserved m attribute;
//   - checks that the server nonce starts with the client nonce;
//   - derives the salted password with PBKDF2.
//
// It returns the client-final-message (channel binding, nonce and client
// proof), and caches the expected server signature for the next step.
//
// In the ResponseSent step, challenge is the server-final-message and data
// must be the cached server signature. Authentication succeeds only if
// challenge is exactly "v=" followed by that signature in base64; otherwise
// ErrAuthn is returned.
//
// A cache value of the wrong type, or any other step, yields ErrInvalidState.
func scramClientNext(name string, fn func() hash.Hash, m *Negotiator, challenge []byte, data interface{}) (more bool, resp []byte, cache interface{}, err error) {
	_, password, _ := m.Credentials()
	state := m.State()
	switch state & StepMask {
	case AuthTextSent:
		clientFirstMessage, ok := data.([]byte)
		if !ok {
			err = ErrInvalidState
			return
		}

		iter := -1 // -1 means "not sent by the server".
		var salt, nonce []byte
		for _, field := range bytes.Split(challenge, []byte{','}) {
			// Attributes look like "<letter>=<value>". Skip anything malformed
			// or with an empty value.
			if len(field) < 3 || field[1] != '=' {
				continue
			}
			switch field[0] {
			case 'i':
				ival := string(bytes.TrimRight(field[2:], "\x00"))
				if iter, err = strconv.Atoi(ival); err != nil {
					return
				}
				// iteration-count is a positive number (RFC 5802).
				if iter < 1 {
					err = errors.New("Iteration count is invalid")
					return
				}
			case 's':
				salt = make([]byte, base64.StdEncoding.DecodedLen(len(field)-2))
				var n int
				n, err = base64.StdEncoding.Decode(salt, field[2:])
				salt = salt[:n]
				if err != nil {
					return
				}
			case 'r':
				nonce = field[2:]
			case 'm':
				// RFC 5802:
				// m: This attribute is reserved for future extensibility. In this
				// version of SCRAM, its presence in a client or a server message
				// MUST cause authentication failure when the attribute is parsed by
				// the other end.
				err = errors.New("Server sent reserved attribute `m'")
				return
			}
		}
		switch {
		case iter < 0:
			err = errors.New("Iteration count is missing")
			return
		case nonce == nil || !bytes.HasPrefix(nonce, m.Nonce()):
			err = errors.New("Server nonce does not match client nonce")
			return
		case len(salt) == 0:
			err = errors.New("Server sent empty salt")
			return
		}

		// cbind-input is the GS2 header, followed by the channel binding data
		// only when the "p" flag was sent; it MUST be absent for "y" or "n".
		gs2Header := getGS2Header(name, m)
		cbindInput := gs2Header
		if tlsState := m.TLSState(); tlsState != nil && bytes.HasPrefix(gs2Header, []byte(gs2HeaderCBSupport)) {
			cbindInput = append(cbindInput, tlsState.TLSUnique...)
		}
		channelBinding := make([]byte, 2+base64.StdEncoding.EncodedLen(len(cbindInput)))
		copy(channelBinding, "c=")
		base64.StdEncoding.Encode(channelBinding[2:], cbindInput)

		clientFinalMessageWithoutProof := append(channelBinding, ",r="...)
		clientFinalMessageWithoutProof = append(clientFinalMessageWithoutProof, nonce...)

		// AuthMessage = client-first-message-bare "," server-first-message ","
		// client-final-message-without-proof. It is built in a fresh buffer so
		// that append can never write into the cached client-first-message.
		authMessage := make([]byte, 0, len(clientFirstMessage)+len(challenge)+len(clientFinalMessageWithoutProof)+2)
		authMessage = append(authMessage, clientFirstMessage...)
		authMessage = append(authMessage, ',')
		authMessage = append(authMessage, challenge...)
		authMessage = append(authMessage, ',')
		authMessage = append(authMessage, clientFinalMessageWithoutProof...)

		saltedPassword := pbkdf2.Key(password, salt, iter, fn().Size(), fn)
		h := hmac.New(fn, saltedPassword)
		h.Write(serverKeyInput)
		serverKey := h.Sum(nil)
		h.Reset()
		h.Write(clientKeyInput)
		clientKey := h.Sum(nil)

		h = hmac.New(fn, serverKey)
		h.Write(authMessage)
		serverSignature := h.Sum(nil)

		// StoredKey = H(ClientKey); ClientSignature = HMAC(StoredKey, AuthMessage).
		h = fn()
		h.Write(clientKey)
		storedKey := h.Sum(nil)
		h = hmac.New(fn, storedKey)
		h.Write(authMessage)
		clientSignature := h.Sum(nil)

		// ClientProof = ClientKey XOR ClientSignature.
		clientProof := make([]byte, len(clientKey))
		xorBytes(clientProof, clientKey, clientSignature)
		encodedClientProof := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
		base64.StdEncoding.Encode(encodedClientProof, clientProof)

		clientFinalMessage := append(clientFinalMessageWithoutProof, ",p="...)
		clientFinalMessage = append(clientFinalMessage, encodedClientProof...)
		return true, clientFinalMessage, serverSignature, nil
	case ResponseSent:
		serverSignature, ok := data.([]byte)
		if !ok {
			err = ErrInvalidState
			return
		}
		clientCalculatedServerFinalMessage := "v=" + base64.StdEncoding.EncodeToString(serverSignature)
		// Compare in constant time so no timing information leaks.
		if !hmac.Equal([]byte(clientCalculatedServerFinalMessage), challenge) {
			err = ErrAuthn
			return
		}
		// Success!
		return false, nil, nil, nil
	}
	err = ErrInvalidState
	return
}