[REFACTOR] PKT protocol

- Use `Fprintf` to convert to hex and do padding. Simplifies the code.
- Use `Read()` and `io.ReadFull` instead of `ReadByte()`. Should improve
performance and allows for cleaner code.
- s/pktLineTypeUnknow/pktLineTypeUnknown.
- Disallow empty Pkt line per the specification.
- Disallow too large Pkt line per the specification.
- Add unit tests.
This commit is contained in:
Gusted 2024-03-29 00:20:21 +01:00
parent a11116602e
commit 2c8bcc163e
No known key found for this signature in database
GPG key ID: FD821B732837125F
3 changed files with 93 additions and 52 deletions

View file

@ -583,7 +583,7 @@ Forgejo or set your environment appropriately.`, "")
for {
// note: pktLineTypeUnknow means pktLineTypeFlush and pktLineTypeData all allowed
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
if err != nil {
return err
}
@ -604,7 +604,7 @@ Forgejo or set your environment appropriately.`, "")
if hasPushOptions {
for {
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow)
rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
if err != nil {
return err
}
@ -699,8 +699,8 @@ Forgejo or set your environment appropriately.`, "")
type pktLineType int64
const (
// UnKnow type
pktLineTypeUnknow pktLineType = 0
// Unknown type
pktLineTypeUnknown pktLineType = 0
// flush-pkt "0000"
pktLineTypeFlush pktLineType = iota
// data line
@ -714,22 +714,16 @@ type gitPktLine struct {
Data []byte
}
// Reads an Pkt-Line from `in`. If requestType is not unknown, it will a
func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType) (*gitPktLine, error) {
var (
err error
r *gitPktLine
)
// read prefix
// Read length prefix
lengthBytes := make([]byte, 4)
for i := 0; i < 4; i++ {
lengthBytes[i], err = in.ReadByte()
if err != nil {
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
}
if n, err := in.Read(lengthBytes); n != 4 || err != nil {
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
}
r = new(gitPktLine)
var err error
r := &gitPktLine{}
r.Length, err = strconv.ParseUint(string(lengthBytes), 16, 32)
if err != nil {
return nil, fail(ctx, "Protocol: format parse error", "Pkt-Line format is wrong :%v", err)
@ -748,11 +742,8 @@ func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType)
}
r.Data = make([]byte, r.Length-4)
for i := range r.Data {
r.Data[i], err = in.ReadByte()
if err != nil {
return nil, fail(ctx, "Protocol: data error", "Pkt-Line: read stdin failed : %v", err)
}
if n, err := io.ReadFull(in, r.Data); uint64(n) != r.Length-4 || err != nil {
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
}
r.Type = pktLineTypeData
@ -768,20 +759,23 @@ func writeFlushPktLine(ctx context.Context, out io.Writer) error {
return nil
}
// Write an Pkt-Line based on `data` to `out` according to the specifcation.
// https://git-scm.com/docs/protocol-common
func writeDataPktLine(ctx context.Context, out io.Writer, data []byte) error {
hexchar := []byte("0123456789abcdef")
hex := func(n uint64) byte {
return hexchar[(n)&15]
// Implementations SHOULD NOT send an empty pkt-line ("0004").
if len(data) == 0 {
return fail(ctx, "Protocol: write error", "Not allowed to write empty Pkt-Line")
}
length := uint64(len(data) + 4)
tmp := make([]byte, 4)
tmp[0] = hex(length >> 12)
tmp[1] = hex(length >> 8)
tmp[2] = hex(length >> 4)
tmp[3] = hex(length)
lr, err := out.Write(tmp)
// The maximum length of a pkt-lines data component is 65516 bytes.
// Implementations MUST NOT send pkt-line whose length exceeds 65520 (65516 bytes of payload + 4 bytes of length data).
if length > 65520 {
return fail(ctx, "Protocol: write error", "Pkt-Line exceeds maximum of 65520 bytes")
}
lr, err := fmt.Fprintf(out, "%04x", length)
if err != nil || lr != 4 {
return fail(ctx, "Protocol: write error", "Pkt-Line response failed: %v", err)
}

View file

@ -14,29 +14,72 @@ import (
)
func TestPktLine(t *testing.T) {
// test read
ctx := context.Background()
s := strings.NewReader("0000")
r := bufio.NewReader(s)
result, err := readPktLine(ctx, r, pktLineTypeFlush)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeFlush, result.Type)
s = strings.NewReader("0006a\n")
r = bufio.NewReader(s)
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeData, result.Type)
assert.Equal(t, []byte("a\n"), result.Data)
t.Run("Read", func(t *testing.T) {
s := strings.NewReader("0000")
r := bufio.NewReader(s)
result, err := readPktLine(ctx, r, pktLineTypeFlush)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeFlush, result.Type)
// test write
w := bytes.NewBuffer([]byte{})
err = writeFlushPktLine(ctx, w)
assert.NoError(t, err)
assert.Equal(t, []byte("0000"), w.Bytes())
s = strings.NewReader("0006a\n")
r = bufio.NewReader(s)
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeData, result.Type)
assert.Equal(t, []byte("a\n"), result.Data)
w.Reset()
err = writeDataPktLine(ctx, w, []byte("a\nb"))
assert.NoError(t, err)
assert.Equal(t, []byte("0007a\nb"), w.Bytes())
s = strings.NewReader("0004")
r = bufio.NewReader(s)
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.Error(t, err)
assert.Nil(t, result)
data := strings.Repeat("x", 65516)
r = bufio.NewReader(strings.NewReader("fff0" + data))
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeData, result.Type)
assert.Equal(t, []byte(data), result.Data)
r = bufio.NewReader(strings.NewReader("fff1a"))
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.Error(t, err)
assert.Nil(t, result)
})
t.Run("Write", func(t *testing.T) {
w := bytes.NewBuffer([]byte{})
err := writeFlushPktLine(ctx, w)
assert.NoError(t, err)
assert.Equal(t, []byte("0000"), w.Bytes())
w.Reset()
err = writeDataPktLine(ctx, w, []byte("a\nb"))
assert.NoError(t, err)
assert.Equal(t, []byte("0007a\nb"), w.Bytes())
w.Reset()
data := bytes.Repeat([]byte{0x05}, 288)
err = writeDataPktLine(ctx, w, data)
assert.NoError(t, err)
assert.Equal(t, append([]byte("0124"), data...), w.Bytes())
w.Reset()
err = writeDataPktLine(ctx, w, nil)
assert.Error(t, err)
assert.Empty(t, w.Bytes())
w.Reset()
data = bytes.Repeat([]byte{0x64}, 65516)
err = writeDataPktLine(ctx, w, data)
assert.NoError(t, err)
assert.Equal(t, append([]byte("fff0"), data...), w.Bytes())
w.Reset()
err = writeDataPktLine(ctx, w, bytes.Repeat([]byte{0x64}, 65516+1))
assert.Error(t, err)
assert.Empty(t, w.Bytes())
})
}

View file

@ -14,6 +14,7 @@ import (
"regexp"
"strconv"
"strings"
"testing"
"time"
"unicode"
@ -106,7 +107,10 @@ func fail(ctx context.Context, userMessage, logMsgFmt string, args ...any) error
logMsg = userMessage + ". " + logMsg
}
}
_ = private.SSHLog(ctx, true, logMsg)
// Don't send an log if this is done in a test and no InternalToken is set.
if !testing.Testing() || setting.InternalToken != "" {
_ = private.SSHLog(ctx, true, logMsg)
}
}
return cli.Exit("", 1)
}