[chore]: Bump github.com/jackc/pgx/v5 from 5.5.3 to 5.5.5 (#2747)

This commit is contained in:
dependabot[bot] 2024-03-11 10:13:33 +00:00 committed by GitHub
parent e24efcac8b
commit d115f9ebc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
67 changed files with 515 additions and 385 deletions

2
go.mod
View file

@ -37,7 +37,7 @@ require (
github.com/gorilla/feeds v1.1.2 github.com/gorilla/feeds v1.1.2
github.com/gorilla/websocket v1.5.1 github.com/gorilla/websocket v1.5.1
github.com/h2non/filetype v1.1.3 github.com/h2non/filetype v1.1.3
github.com/jackc/pgx/v5 v5.5.3 github.com/jackc/pgx/v5 v5.5.5
github.com/microcosm-cc/bluemonday v1.0.26 github.com/microcosm-cc/bluemonday v1.0.26
github.com/miekg/dns v1.1.58 github.com/miekg/dns v1.1.58
github.com/minio/minio-go/v7 v7.0.67 github.com/minio/minio-go/v7 v7.0.67

4
go.sum
View file

@ -421,8 +421,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.3 h1:Ces6/M3wbDXYpM8JyyPD57ivTtJACFZJd885pdIaV2s= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.3/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=

View file

@ -1,3 +1,29 @@
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024) # 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists * Fix: prepared statement already exists

View file

@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) * [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)

View file

@ -1354,7 +1354,7 @@ order by attnum`,
} }
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
if c.pgConn.TxStatus() != 'I' { if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
return nil return nil
} }

View file

@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
return "", fmt.Errorf("invalid arg type: %T", arg) return "", fmt.Errorf("invalid arg type: %T", arg)
} }
argUse[argIdx] = true argUse[argIdx] = true
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
default: default:
return "", fmt.Errorf("invalid Part type: %T", part) return "", fmt.Errorf("invalid Part type: %T", part)
} }

View file

@ -721,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, fmt.Errorf("unable to read sslkey: %w", err) return nil, fmt.Errorf("unable to read sslkey: %w", err)
} }
block, _ := pem.Decode(buf) block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte var pemKey []byte
var decryptedKey []byte var decryptedKey []byte
var decryptedError error var decryptedError error

View file

@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct { type Batch struct {
buf []byte buf []byte
err error
} }
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.ExecPrepared("", paramValues, paramFormats, resultFormats) batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
} }
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) if batch.err != nil {
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) return
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) }
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
} }
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
// multiple queries in a single round trip than using pipeline mode. // multiple queries in a single round trip than using pipeline mode.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
if batch.err != nil {
return &MultiResultReader{
closed: true,
err: batch.err,
}
}
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return &MultiResultReader{ return &MultiResultReader{
closed: true, closed: true,
@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
} }
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
return multiResult
}
pgConn.enterPotentialWriteReadDeadlock() pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock() defer pgConn.exitPotentialWriteReadDeadlock()
@ -2094,6 +2130,8 @@ func (p *Pipeline) getResults() (results any, err error) {
for { for {
msg, err := p.conn.receiveMessage() msg, err := p.conn.receiveMessage()
if err != nil { if err != nil {
p.closed = true
p.err = err
p.conn.asyncClose() p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err) return nil, normalizeTimeoutError(p.ctx, err)
} }

View file

@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil return nil
} }
func (a *AuthenticationGSS) Encode(dst []byte) []byte { func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS) dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst return finishMessage(dst, sp)
} }
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

View file

@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil return nil
} }
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont) dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...) dst = append(dst, a.Data...)
return dst return finishMessage(dst, sp)
} }
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

View file

@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...) dst = append(dst, src.Salt[:]...)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte { func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk) dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte { func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL) dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms { for _, s := range src.AuthMechanisms {
@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
} }
dst = append(dst, 0) dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...) dst = append(dst, src.Data...)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'R') dst, sp := beginMessage(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...) dst = append(dst, src.Data...)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Unmarshaler. // MarshalJSON implements encoding/json.Unmarshaler.

View file

@ -17,6 +17,7 @@ type Backend struct {
tracer *tracer tracer *tracer
wbuf []byte wbuf []byte
encodeError error
// Frontend message flyweights // Frontend message flyweights
bind Bind bind Bind
@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w} return &Backend{cr: cr, w: w}
} }
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is // Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// called. // encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) { func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}
prevLen := len(b.wbuf) prevLen := len(b.wbuf)
b.wbuf = msg.Encode(b.wbuf) newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf
if b.tracer != nil { if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
} }
@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
// Flush writes any pending messages to the frontend (i.e. the client). // Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error { func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
n, err := b.w.Write(b.wbuf) n, err := b.w.Write(b.wbuf)
const maxLen = 1024 const maxLen = 1024

View file

@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte { func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'K') dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey) dst = pgio.AppendUint32(dst, src.SecretKey)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -5,7 +5,9 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte { func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'B') dst, sp := beginMessage(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.DestinationPortal...) dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0) dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...) dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0) dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes { for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc) dst = pgio.AppendInt16(dst, fc)
} }
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters { for _, p := range src.Parameters {
if p == nil { if p == nil {
@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, p...) dst = append(dst, p...)
} }
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes { for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc) dst = pgio.AppendInt16(dst, fc)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte { func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4) return append(dst, '2', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 4 byte message length. // Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte { func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey) dst = pgio.AppendUint32(dst, src.SecretKey)
return dst return dst, nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type Close struct { type Close struct {
@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte { func (src *Close) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'C') dst, sp := beginMessage(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType) dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...) dst = append(dst, src.Name...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte { func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4) return append(dst, '3', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type CommandComplete struct { type CommandComplete struct {
@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte { func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'C') dst, sp := beginMessage(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.CommandTag...) dst = append(dst, src.CommandTag...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) []byte { func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'W') dst, sp := beginMessage(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat) dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes { for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc) dst = pgio.AppendUint16(dst, fc)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type CopyData struct { type CopyData struct {
@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) []byte { func (src *CopyData) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'd') dst, sp := beginMessage(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
dst = append(dst, src.Data...) dst = append(dst, src.Data...)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) []byte { func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
return append(dst, 'c', 0, 0, 0, 4) return append(dst, 'c', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type CopyFail struct { type CopyFail struct {
@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) []byte { func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'f') dst, sp := beginMessage(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Message...) dst = append(dst, src.Message...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) []byte { func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'G') dst, sp := beginMessage(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat) dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes { for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc) dst = pgio.AppendUint16(dst, fc)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) []byte { func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'H') dst, sp := beginMessage(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat) dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes { for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc) dst = pgio.AppendUint16(dst, fc)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,6 +4,8 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) []byte { func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'D') dst, sp := beginMessage(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values))) dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values { for _, v := range src.Values {
if v == nil { if v == nil {
@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, v...) dst = append(dst, v...)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type Describe struct { type Describe struct {
@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) []byte { func (src *Describe) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'D') dst, sp := beginMessage(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType) dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...) dst = append(dst, src.Name...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte { func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, 'I', 0, 0, 0, 4) return append(dst, 'I', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -2,7 +2,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/binary"
"encoding/json" "encoding/json"
"strconv" "strconv"
) )
@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) []byte { func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, src.marshalBinary('E')...) dst, sp := beginMessage(dst, 'E')
dst = src.appendFields(dst)
return finishMessage(dst, sp)
} }
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { func (src *ErrorResponse) appendFields(dst []byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" { if src.Severity != "" {
buf.WriteByte('S') dst = append(dst, 'S')
buf.WriteString(src.Severity) dst = append(dst, src.Severity...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.SeverityUnlocalized != "" { if src.SeverityUnlocalized != "" {
buf.WriteByte('V') dst = append(dst, 'V')
buf.WriteString(src.SeverityUnlocalized) dst = append(dst, src.SeverityUnlocalized...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Code != "" { if src.Code != "" {
buf.WriteByte('C') dst = append(dst, 'C')
buf.WriteString(src.Code) dst = append(dst, src.Code...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Message != "" { if src.Message != "" {
buf.WriteByte('M') dst = append(dst, 'M')
buf.WriteString(src.Message) dst = append(dst, src.Message...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Detail != "" { if src.Detail != "" {
buf.WriteByte('D') dst = append(dst, 'D')
buf.WriteString(src.Detail) dst = append(dst, src.Detail...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Hint != "" { if src.Hint != "" {
buf.WriteByte('H') dst = append(dst, 'H')
buf.WriteString(src.Hint) dst = append(dst, src.Hint...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Position != 0 { if src.Position != 0 {
buf.WriteByte('P') dst = append(dst, 'P')
buf.WriteString(strconv.Itoa(int(src.Position))) dst = append(dst, strconv.Itoa(int(src.Position))...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.InternalPosition != 0 { if src.InternalPosition != 0 {
buf.WriteByte('p') dst = append(dst, 'p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition))) dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.InternalQuery != "" { if src.InternalQuery != "" {
buf.WriteByte('q') dst = append(dst, 'q')
buf.WriteString(src.InternalQuery) dst = append(dst, src.InternalQuery...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Where != "" { if src.Where != "" {
buf.WriteByte('W') dst = append(dst, 'W')
buf.WriteString(src.Where) dst = append(dst, src.Where...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.SchemaName != "" { if src.SchemaName != "" {
buf.WriteByte('s') dst = append(dst, 's')
buf.WriteString(src.SchemaName) dst = append(dst, src.SchemaName...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.TableName != "" { if src.TableName != "" {
buf.WriteByte('t') dst = append(dst, 't')
buf.WriteString(src.TableName) dst = append(dst, src.TableName...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.ColumnName != "" { if src.ColumnName != "" {
buf.WriteByte('c') dst = append(dst, 'c')
buf.WriteString(src.ColumnName) dst = append(dst, src.ColumnName...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.DataTypeName != "" { if src.DataTypeName != "" {
buf.WriteByte('d') dst = append(dst, 'd')
buf.WriteString(src.DataTypeName) dst = append(dst, src.DataTypeName...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.ConstraintName != "" { if src.ConstraintName != "" {
buf.WriteByte('n') dst = append(dst, 'n')
buf.WriteString(src.ConstraintName) dst = append(dst, src.ConstraintName...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.File != "" { if src.File != "" {
buf.WriteByte('F') dst = append(dst, 'F')
buf.WriteString(src.File) dst = append(dst, src.File...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Line != 0 { if src.Line != 0 {
buf.WriteByte('L') dst = append(dst, 'L')
buf.WriteString(strconv.Itoa(int(src.Line))) dst = append(dst, strconv.Itoa(int(src.Line))...)
buf.WriteByte(0) dst = append(dst, 0)
} }
if src.Routine != "" { if src.Routine != "" {
buf.WriteByte('R') dst = append(dst, 'R')
buf.WriteString(src.Routine) dst = append(dst, src.Routine...)
buf.WriteByte(0) dst = append(dst, 0)
} }
for k, v := range src.UnknownFields { for k, v := range src.UnknownFields {
buf.WriteByte(k) dst = append(dst, k)
buf.WriteString(v) dst = append(dst, v...)
buf.WriteByte(0) dst = append(dst, 0)
} }
buf.WriteByte(0) dst = append(dst, 0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) return dst
return buf.Bytes()
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) []byte { func (src *Execute) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'E') dst, sp := beginMessage(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Portal...) dst = append(dst, src.Portal...)
dst = append(dst, 0) dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows) dst = pgio.AppendUint32(dst, src.MaxRows)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) []byte { func (src *Flush) Encode(dst []byte) ([]byte, error) {
return append(dst, 'H', 0, 0, 0, 4) return append(dst, 'H', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -19,6 +19,7 @@ type Frontend struct {
tracer *tracer tracer *tracer
wbuf []byte wbuf []byte
encodeError error
// Backend message flyweights // Backend message flyweights
authenticationOk AuthenticationOk authenticationOk AuthenticationOk
@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w} return &Frontend{cr: cr, w: w}
} }
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is // Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
// called. // encountered will be returned from Flush.
// //
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
// behind an interface. // behind an interface.
func (f *Frontend) Send(msg FrontendMessage) { func (f *Frontend) Send(msg FrontendMessage) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
} }
@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
// Flush writes any pending messages to the backend (i.e. the server). // Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error { func (f *Frontend) Flush() error {
if err := f.encodeError; err != nil {
f.encodeError = nil
f.wbuf = f.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
if len(f.wbuf) == 0 { if len(f.wbuf) == 0 {
return nil return nil
} }
@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
f.tracer = nil f.tracer = nil
} }
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until // SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// Flush is called. // error encountered will be returned from Flush.
func (f *Frontend) SendBind(msg *Bind) { func (f *Frontend) SendBind(msg *Bind) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until // SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// Flush is called. // error encountered will be returned from Flush.
func (f *Frontend) SendParse(msg *Parse) { func (f *Frontend) SendParse(msg *Parse) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until // SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// Flush is called. // error encountered will be returned from Flush.
func (f *Frontend) SendClose(msg *Close) { func (f *Frontend) SendClose(msg *Close) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until // SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
// Flush is called. // called. Any error encountered will be returned from Flush.
func (f *Frontend) SendDescribe(msg *Describe) { func (f *Frontend) SendDescribe(msg *Describe) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until // SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
// Flush is called. // Any error encountered will be returned from Flush.
func (f *Frontend) SendExecute(msg *Execute) { func (f *Frontend) SendExecute(msg *Execute) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until // SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// Flush is called. // error encountered will be returned from Flush.
func (f *Frontend) SendSync(msg *Sync) { func (f *Frontend) SendSync(msg *Sync) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
} }
} }
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until // SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// Flush is called. // error encountered will be returned from Flush.
func (f *Frontend) SendQuery(msg *Query) { func (f *Frontend) SendQuery(msg *Query) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf) prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf) newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil { if f.tracer != nil {
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
} }

View file

@ -2,6 +2,8 @@ package pgproto3
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte { func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'F') dst, sp := beginMessage(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint32(dst, src.Function)
if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes { for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode) dst = pgio.AppendUint16(dst, argFormatCode)
} }
if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments { for _, argument := range src.Arguments {
if argument == nil { if argument == nil {
@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
} }
} }
dst = pgio.AppendUint16(dst, src.ResultFormatCode) dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }

View file

@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) []byte { func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'V') dst, sp := beginMessage(dst, 'V')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if src.Result == nil { if src.Result == nil {
dst = pgio.AppendInt32(dst, -1) dst = pgio.AppendInt32(dst, -1)
@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, src.Result...) dst = append(dst, src.Result...)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 4 byte message length. // Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) []byte { func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber) dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst return dst, nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -2,8 +2,6 @@ package pgproto3
import ( import (
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type GSSResponse struct { type GSSResponse struct {
@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
return nil return nil
} }
func (g *GSSResponse) Encode(dst []byte) []byte { func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'p') dst, sp := beginMessage(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...) dst = append(dst, g.Data...)
return dst return finishMessage(dst, sp)
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) []byte { func (src *NoData) Encode(dst []byte) ([]byte, error) {
return append(dst, 'n', 0, 0, 0, 4) return append(dst, 'n', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) []byte { func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) dst, sp := beginMessage(dst, 'N')
dst = (*ErrorResponse)(src).appendFields(dst)
return finishMessage(dst, sp)
} }

View file

@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) []byte { func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'A') dst, sp := beginMessage(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.PID) dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...) dst = append(dst, src.Channel...)
dst = append(dst, 0) dst = append(dst, 0)
dst = append(dst, src.Payload...) dst = append(dst, src.Payload...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) []byte { func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 't') dst, sp := beginMessage(dst, 't')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs { for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid) dst = pgio.AppendUint32(dst, oid)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type ParameterStatus struct { type ParameterStatus struct {
@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) []byte { func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'S') dst, sp := beginMessage(dst, 'S')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...) dst = append(dst, src.Name...)
dst = append(dst, 0) dst = append(dst, 0)
dst = append(dst, src.Value...) dst = append(dst, src.Value...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) []byte { func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'P') dst, sp := beginMessage(dst, 'P')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...) dst = append(dst, src.Name...)
dst = append(dst, 0) dst = append(dst, 0)
dst = append(dst, src.Query...) dst = append(dst, src.Query...)
dst = append(dst, 0) dst = append(dst, 0)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs { for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid) dst = pgio.AppendUint32(dst, oid)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParseComplete) Encode(dst []byte) []byte { func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '1', 0, 0, 0, 4) return append(dst, '1', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type PasswordMessage struct { type PasswordMessage struct {
@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PasswordMessage) Encode(dst []byte) []byte { func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'p') dst, sp := beginMessage(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
dst = append(dst, src.Password...) dst = append(dst, src.Password...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,8 +4,14 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/jackc/pgx/v5/internal/pgio"
) )
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
const maxMessageBodyLen = (0x3fffffff - 1)
// Message is the interface implemented by an object that can decode and encode // Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message. // a particular PostgreSQL message.
type Message interface { type Message interface {
@ -14,7 +20,7 @@ type Message interface {
Decode(data []byte) error Decode(data []byte) error
// Encode appends itself to dst and returns the new buffer. // Encode appends itself to dst and returns the new buffer.
Encode(dst []byte) []byte Encode(dst []byte) ([]byte, error)
} }
// FrontendMessage is a message sent by the frontend (i.e. the client). // FrontendMessage is a message sent by the frontend (i.e. the client).
@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
} }
return nil, errors.New("unknown protocol representation") return nil, errors.New("unknown protocol representation")
} }
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
// dst. It returns the new buffer and the position of the message length placeholder.
func beginMessage(dst []byte, t byte) ([]byte, int) {
dst = append(dst, t)
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
return dst, sp
}
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
func finishMessage(dst []byte, sp int) ([]byte, error) {
messageBodyLen := len(dst[sp:])
if messageBodyLen > maxMessageBodyLen {
return nil, errors.New("message body too large")
}
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
return dst, nil
}

View file

@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PortalSuspended) Encode(dst []byte) []byte { func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
return append(dst, 's', 0, 0, 0, 4) return append(dst, 's', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type Query struct { type Query struct {
@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Query) Encode(dst []byte) []byte { func (src *Query) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'Q') dst, sp := beginMessage(dst, 'Q')
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
dst = append(dst, src.String...) dst = append(dst, src.String...)
dst = append(dst, 0) dst = append(dst, 0)
return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ReadyForQuery) Encode(dst []byte) []byte { func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *RowDescription) Encode(dst []byte) []byte { func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'T') dst, sp := beginMessage(dst, 'T')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if len(src.Fields) > math.MaxUint16 {
return nil, errors.New("too many fields")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields { for _, fd := range src.Fields {
dst = append(dst, fd.Name...) dst = append(dst, fd.Name...)
@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
dst = pgio.AppendInt16(dst, fd.Format) dst = pgio.AppendInt16(dst, fd.Format)
} }
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLInitialResponse) Encode(dst []byte) []byte { func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'p') dst, sp := beginMessage(dst, 'p')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, []byte(src.AuthMechanism)...) dst = append(dst, []byte(src.AuthMechanism)...)
dst = append(dst, 0) dst = append(dst, 0)
@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...) dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -3,8 +3,6 @@ package pgproto3
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
) )
type SASLResponse struct { type SASLResponse struct {
@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLResponse) Encode(dst []byte) []byte { func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, 'p') dst, sp := beginMessage(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
dst = append(dst, src.Data...) dst = append(dst, src.Data...)
return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 4 byte message length. // Encode encodes src into dst. dst will include the 4 byte message length.
func (src *SSLRequest) Encode(dst []byte) []byte { func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, sslRequestNumber) dst = pgio.AppendInt32(dst, sslRequestNumber)
return dst return dst, nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *StartupMessage) Encode(dst []byte) []byte { func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
sp := len(dst) sp := len(dst)
dst = pgio.AppendInt32(dst, -1) dst = pgio.AppendInt32(dst, -1)
@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
} }
dst = append(dst, 0) dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return finishMessage(dst, sp)
return dst
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Sync) Encode(dst []byte) []byte { func (src *Sync) Encode(dst []byte) ([]byte, error) {
return append(dst, 'S', 0, 0, 0, 4) return append(dst, 'S', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
} }
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Terminate) Encode(dst []byte) []byte { func (src *Terminate) Encode(dst []byte) ([]byte, error) {
return append(dst, 'X', 0, 0, 0, 4) return append(dst, 'X', 0, 0, 0, 4), nil
} }
// MarshalJSON implements encoding/json.Marshaler. // MarshalJSON implements encoding/json.Marshaler.

View file

@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
bitLen := int32(binary.BigEndian.Uint32(src)) bitLen := int32(binary.BigEndian.Uint32(src))
rp := 4 rp := 4
buf := make([]byte, len(src[rp:]))
copy(buf, src[rp:])
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
} }
type scanPlanTextAnyToBitsScanner struct{} type scanPlanTextAnyToBitsScanner struct{}

View file

@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
return nil, nil return nil, nil
} }
var n float64 var n float32
err := codecScan(c, m, oid, format, src, &n) err := codecScan(c, m, oid, format, src, &n)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return n, nil return float64(n), nil
} }
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {

View file

@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
case []byte: case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{} return encodePlanJSONCodecEitherFormatByteSlice{}
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
case json.RawMessage:
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled. // Cannot rely on driver.Valuer being handled later because anything can be marshalled.
// //
// https://github.com/jackc/pgx/issues/1430 // https://github.com/jackc/pgx/issues/1430
@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
return buf, nil return buf, nil
} }
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes := value.(json.RawMessage)
if jsonBytes == nil {
return nil, nil
}
buf = append(buf, jsonBytes...)
return buf, nil
}
type encodePlanJSONCodecEitherFormatMarshal struct{} type encodePlanJSONCodecEitherFormatMarshal struct{}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {

View file

@ -561,7 +561,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
} }
} }
if nextDstType != nil && dstValue.Type() != nextDstType { if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
} }

View file

@ -1,6 +1,7 @@
package pgtype package pgtype
import ( import (
"encoding/json"
"net" "net"
"net/netip" "net/netip"
"reflect" "reflect"
@ -173,6 +174,7 @@ func initDefaultMap() {
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
registerDefaultPgTypeVariants[string](defaultMap, "text") registerDefaultPgTypeVariants[string](defaultMap, "text")
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")

View file

@ -438,7 +438,7 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T(nil), rows, fn) return AppendRows([]T{}, rows, fn)
} }
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.

2
vendor/modules.txt vendored
View file

@ -405,7 +405,7 @@ github.com/jackc/pgpassfile
# github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a
## explicit; go 1.14 ## explicit; go 1.14
github.com/jackc/pgservicefile github.com/jackc/pgservicefile
# github.com/jackc/pgx/v5 v5.5.3 # github.com/jackc/pgx/v5 v5.5.5
## explicit; go 1.19 ## explicit; go 1.19
github.com/jackc/pgx/v5 github.com/jackc/pgx/v5
github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/anynil