gotosocial/vendor/github.com/jackc/pgx/v5/pgtype/composite.go
tobi ec325fee14
[chore] Update a bunch of database dependencies (#1772)
* [chore] Update a bunch of database dependencies

* fix lil thing
2023-05-12 14:33:40 +02:00

603 lines
14 KiB
Go

package pgtype
import (
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/internal/pgio"
)
type (
	// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite.
	// Encoders ask IsNull first and, when it is false, read each field in order with Index.
	CompositeIndexGetter interface {
		// IsNull reports whether the whole composite value is SQL NULL.
		IsNull() bool

		// Index returns the value of the field at position i.
		Index(i int) any
	}

	// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite.
	CompositeIndexScanner interface {
		// ScanNull is called when the whole composite value is SQL NULL.
		ScanNull() error

		// ScanIndex returns the scan target for the field at position i. Returning nil skips that field.
		ScanIndex(i int) any
	}
)
type (
	// CompositeCodecField describes one attribute of a composite type.
	CompositeCodecField struct {
		// Name is the attribute name; DecodeValue uses it as the key of the returned map.
		Name string
		// Type is the attribute's registered type; its OID and Codec drive encoding and scanning.
		Type *Type
	}

	// CompositeCodec encodes and decodes PostgreSQL composite (row) values whose attributes are listed,
	// in order, in Fields.
	CompositeCodec struct {
		Fields []CompositeCodecField
	}
)
// FormatSupported reports whether every field's codec supports format. It checks fields in order
// and stops at the first one that does not; a codec with no fields supports every format.
func (c *CompositeCodec) FormatSupported(format int16) bool {
	for idx := range c.Fields {
		if fieldCodec := c.Fields[idx].Type.Codec; !fieldCodec.FormatSupported(format) {
			return false
		}
	}
	return true
}
// PreferredFormat returns BinaryFormatCode when every field supports binary, and TextFormatCode otherwise.
func (c *CompositeCodec) PreferredFormat() int16 {
	if !c.FormatSupported(BinaryFormatCode) {
		return TextFormatCode
	}
	return BinaryFormatCode
}
// PlanEncode returns a plan that encodes a CompositeIndexGetter in the requested format. It returns
// nil when value is not a CompositeIndexGetter or format is neither binary nor text.
func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	if _, isGetter := value.(CompositeIndexGetter); !isGetter {
		return nil
	}

	if format == BinaryFormatCode {
		return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m}
	}
	if format == TextFormatCode {
		return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m}
	}
	return nil
}
// encodePlanCompositeCodecCompositeIndexGetterToBinary encodes a CompositeIndexGetter as a binary composite.
type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
	cc *CompositeCodec
	m  *Map
}

// Encode appends the binary composite for value (a CompositeIndexGetter) to buf. A NULL getter
// yields (nil, nil). Otherwise every codec field is appended in order using that field's type OID.
func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
	src := value.(CompositeIndexGetter)
	if src.IsNull() {
		return nil, nil
	}

	b := NewCompositeBinaryBuilder(plan.m, buf)
	fields := plan.cc.Fields
	for idx := range fields {
		b.AppendValue(fields[idx].Type.OID, src.Index(idx))
	}
	return b.Finish()
}
// encodePlanCompositeCodecCompositeIndexGetterToText encodes a CompositeIndexGetter as a text composite.
type encodePlanCompositeCodecCompositeIndexGetterToText struct {
	cc *CompositeCodec
	m  *Map
}

// Encode appends the text composite for value (a CompositeIndexGetter) to buf. A NULL getter yields
// (nil, nil). Otherwise every codec field is appended in order using that field's type OID.
func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
	src := value.(CompositeIndexGetter)
	if src.IsNull() {
		return nil, nil
	}

	builder := NewCompositeTextBuilder(plan.m, buf)
	fields := plan.cc.Fields
	for idx := range fields {
		builder.AppendValue(fields[idx].Type.OID, src.Index(idx))
	}
	return builder.Finish()
}
// PlanScan returns a plan that scans a composite in the given format into a CompositeIndexScanner.
// It returns nil when target is not a CompositeIndexScanner or format is neither binary nor text.
func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
	if _, isScanner := target.(CompositeIndexScanner); !isScanner {
		return nil
	}

	switch format {
	case BinaryFormatCode:
		return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
	case TextFormatCode:
		return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
	default:
		return nil
	}
}
// scanPlanBinaryCompositeToCompositeIndexScanner scans a binary composite into a CompositeIndexScanner.
type scanPlanBinaryCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec
	m  *Map
}

// Scan decodes the binary composite src into target, which must implement CompositeIndexScanner.
// A nil src is forwarded to target.ScanNull. Each codec field is read in order and scanned into the
// value returned by target.ScanIndex; a nil scan target skips that field. An error is returned if
// src holds fewer fields than the codec declares, or if the input is malformed.
func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
	targetScanner := target.(CompositeIndexScanner)

	if src == nil {
		return targetScanner.ScanNull()
	}

	scanner := NewCompositeBinaryScanner(plan.m, src)
	for i, field := range plan.cc.Fields {
		if !scanner.Next() {
			// Surface the scanner's own error (e.g. truncated input) rather than masking it.
			if err := scanner.Err(); err != nil {
				return err
			}
			return errors.New("read past end of composite")
		}

		fieldTarget := targetScanner.ScanIndex(i)
		if fieldTarget == nil {
			continue
		}

		fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
		if fieldPlan == nil {
			return fmt.Errorf("unable to scan OID %d in binary format into %T (field %q)", field.Type.OID, fieldTarget, field.Name)
		}
		if err := fieldPlan.Scan(scanner.Bytes(), fieldTarget); err != nil {
			return err
		}
	}

	return scanner.Err()
}
// scanPlanTextCompositeToCompositeIndexScanner scans a text composite into a CompositeIndexScanner.
type scanPlanTextCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec
	m  *Map
}

// Scan decodes the text composite src into target, which must implement CompositeIndexScanner.
// A nil src is forwarded to target.ScanNull. Each codec field is read in order and scanned into the
// value returned by target.ScanIndex; a nil scan target skips that field. An error is returned if
// src holds fewer fields than the codec declares, or if the input is malformed.
func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
	targetScanner := target.(CompositeIndexScanner)

	if src == nil {
		return targetScanner.ScanNull()
	}

	scanner := NewCompositeTextScanner(plan.m, src)
	for i, field := range plan.cc.Fields {
		if !scanner.Next() {
			// Surface the scanner's own error (e.g. missing parentheses) rather than masking it.
			if err := scanner.Err(); err != nil {
				return err
			}
			return errors.New("read past end of composite")
		}

		fieldTarget := targetScanner.ScanIndex(i)
		if fieldTarget == nil {
			continue
		}

		fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
		if fieldPlan == nil {
			return fmt.Errorf("unable to scan OID %d in text format into %T (field %q)", field.Type.OID, fieldTarget, field.Name)
		}
		if err := fieldPlan.Scan(scanner.Bytes(), fieldTarget); err != nil {
			return err
		}
	}

	return scanner.Err()
}
// DecodeDatabaseSQLValue returns src as a driver.Value without decoding individual fields. Text input
// becomes a string, and binary input becomes a fresh []byte copy that does not alias src. A nil src
// yields nil, and any other format is an error.
func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}

	if format == TextFormatCode {
		return string(src), nil
	}
	if format == BinaryFormatCode {
		dup := append(make([]byte, 0, len(src)), src...)
		return dup, nil
	}
	return nil, fmt.Errorf("unknown format code %d", format)
}
// DecodeValue decodes src into a map[string]any keyed by codec field name. A nil src yields nil, and
// an unknown format is an error. At most len(c.Fields) fields are decoded.
func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	switch format {
	case TextFormatCode:
		return c.decodeTextValue(m, src)
	case BinaryFormatCode:
		return c.decodeBinaryValue(m, src)
	default:
		return nil, fmt.Errorf("unknown format code %d", format)
	}
}

// decodeTextValue decodes a text composite, planning each field with the codec field's type OID.
// It returns a literal nil (not a typed nil map) on error.
func (c *CompositeCodec) decodeTextValue(m *Map, src []byte) (any, error) {
	scanner := NewCompositeTextScanner(m, src)
	values := make(map[string]any, len(c.Fields))
	for idx := 0; ; idx++ {
		// Next is evaluated before the field-count bound, so one field beyond len(c.Fields) may be consumed.
		if !scanner.Next() || idx >= len(c.Fields) {
			break
		}

		fieldOID := c.Fields[idx].Type.OID
		var decoded any
		plan := m.PlanScan(fieldOID, TextFormatCode, &decoded)
		if plan == nil {
			return nil, fmt.Errorf("unable to scan OID %d in text format into %v", fieldOID, decoded)
		}
		if err := plan.Scan(scanner.Bytes(), &decoded); err != nil {
			return nil, err
		}
		values[c.Fields[idx].Name] = decoded
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return values, nil
}

// decodeBinaryValue decodes a binary composite, planning each field with the OID carried on the wire.
// It returns a literal nil (not a typed nil map) on error.
func (c *CompositeCodec) decodeBinaryValue(m *Map, src []byte) (any, error) {
	scanner := NewCompositeBinaryScanner(m, src)
	values := make(map[string]any, len(c.Fields))
	for idx := 0; ; idx++ {
		// Next is evaluated before the field-count bound, so one field beyond len(c.Fields) may be consumed.
		if !scanner.Next() || idx >= len(c.Fields) {
			break
		}

		var decoded any
		plan := m.PlanScan(scanner.OID(), BinaryFormatCode, &decoded)
		if plan == nil {
			return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), decoded)
		}
		if err := plan.Scan(scanner.Bytes(), &decoded); err != nil {
			return nil, err
		}
		values[c.Fields[idx].Name] = decoded
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return values, nil
}
// CompositeBinaryScanner iterates over the fields of a binary-format composite value. The layout is a
// 4-byte big-endian field count, followed by one entry per field. Each entry is a 4-byte OID, a 4-byte
// signed length (negative for NULL) and that many bytes of data.
type CompositeBinaryScanner struct {
	m   *Map
	rp  int
	src []byte

	fieldCount int32
	fieldBytes []byte
	fieldOID   uint32

	err error
}

// NewCompositeBinaryScanner returns a scanner over a binary encoded composite value. If src is too
// short to hold the field count, the returned scanner carries an error and Next reports false.
func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
	const headerLen = 4
	if len(src) < headerLen {
		return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
	}

	return &CompositeBinaryScanner{
		m:          m,
		rp:         headerLen,
		src:        src,
		fieldCount: int32(binary.BigEndian.Uint32(src)),
	}
}

// Next advances the scanner to the next field. It returns false after the last field is read or an
// error occurs. After Next returns false, call Err to check whether an error occurred.
func (cfs *CompositeBinaryScanner) Next() bool {
	if cfs.err != nil || cfs.rp == len(cfs.src) {
		return false
	}

	rest := cfs.src[cfs.rp:]
	if len(rest) < 8 {
		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
		return false
	}

	cfs.fieldOID = binary.BigEndian.Uint32(rest[0:4])
	fieldLen := int(int32(binary.BigEndian.Uint32(rest[4:8])))
	cfs.rp += 8

	if fieldLen < 0 {
		// Negative length marks a NULL field.
		cfs.fieldBytes = nil
		return true
	}

	if len(cfs.src)-cfs.rp < fieldLen {
		cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
		return false
	}
	cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
	cfs.rp += fieldLen
	return true
}

// FieldCount returns the field count declared in the composite header. Next does not enforce it.
func (cfs *CompositeBinaryScanner) FieldCount() int {
	return int(cfs.fieldCount)
}

// Bytes returns the bytes of the field most recently read by Next, or nil for a NULL field. The slice
// aliases the scanner's source buffer.
func (cfs *CompositeBinaryScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// OID returns the OID of the field most recently read by Next.
func (cfs *CompositeBinaryScanner) OID() uint32 {
	return cfs.fieldOID
}

// Err returns any error encountered by the scanner.
func (cfs *CompositeBinaryScanner) Err() error {
	return cfs.err
}
// CompositeTextScanner iterates over the fields of a text-format composite value such as
// `(1,"a b",)`. Fields are separated by ',' inside the enclosing parentheses. An empty field is NULL.
// A double-quoted field may contain doubled quotes ("") and backslash escapes.
type CompositeTextScanner struct {
	m          *Map
	rp         int
	src        []byte
	fieldBytes []byte
	err        error
}

// NewCompositeTextScanner returns a scanner over a text encoded composite value. src must be at least
// two bytes long and enclosed in parentheses; otherwise the returned scanner carries an error and
// Next reports false.
func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
	if len(src) < 2 {
		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
	}

	if src[0] != '(' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
	}

	if src[len(src)-1] != ')' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
	}

	return &CompositeTextScanner{
		m:   m,
		rp:  1,
		src: src,
	}
}

// Next advances the scanner to the next field. It returns false after the last field is read or an
// error occurs. After Next returns false, call Err to check whether an error occurred. Malformed
// input, such as an unterminated quoted field, sets Err instead of panicking.
func (cfs *CompositeTextScanner) Next() bool {
	if cfs.err != nil {
		return false
	}

	// Use >= so a position past the end (after a quote at the final byte) can never be indexed.
	if cfs.rp >= len(cfs.src) {
		return false
	}

	switch cfs.src[cfs.rp] {
	case ',', ')': // null
		cfs.rp++
		cfs.fieldBytes = nil
		return true
	case '"': // quoted value
		cfs.rp++
		cfs.fieldBytes = make([]byte, 0, 16)
		for {
			if cfs.rp >= len(cfs.src) {
				cfs.err = errors.New("composite text format has unterminated quoted field")
				return false
			}
			ch := cfs.src[cfs.rp]
			cfs.rp++
			switch ch {
			case '"':
				// A doubled quote is a literal quote; a lone quote closes the field.
				if cfs.rp < len(cfs.src) && cfs.src[cfs.rp] == '"' {
					cfs.fieldBytes = append(cfs.fieldBytes, '"')
					cfs.rp++
					continue
				}
				// Skip the delimiter (',' or ')') that follows the closing quote.
				cfs.rp++
				return true
			case '\\':
				if cfs.rp >= len(cfs.src) {
					cfs.err = errors.New("composite text format has unterminated quoted field")
					return false
				}
				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
				cfs.rp++
			default:
				cfs.fieldBytes = append(cfs.fieldBytes, ch)
			}
		}
	default: // unquoted value
		start := cfs.rp
		for cfs.rp < len(cfs.src) && cfs.src[cfs.rp] != ',' && cfs.src[cfs.rp] != ')' {
			cfs.rp++
		}
		if cfs.rp >= len(cfs.src) {
			// Unreachable for scanners built by NewCompositeTextScanner (src ends with ')'), but
			// never index past the end.
			cfs.err = fmt.Errorf("composite text format must end with ')'")
			return false
		}
		cfs.fieldBytes = cfs.src[start:cfs.rp]
		cfs.rp++ // skip the delimiter
		return true
	}
}

// Bytes returns the bytes of the field most recently read by Next, or nil for a NULL field.
func (cfs *CompositeTextScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// Err returns any error encountered by the scanner.
func (cfs *CompositeTextScanner) Err() error {
	return cfs.err
}
// CompositeBinaryBuilder appends a binary-format composite value to a byte slice. The output is a
// 4-byte field count, then for each field its 4-byte OID, a 4-byte length (-1 for NULL) and the
// encoded bytes. Call AppendValue once per field in order, then Finish.
type CompositeBinaryBuilder struct {
	m          *Map
	buf        []byte
	startIdx   int // offset of the 4-byte field count in buf
	fieldCount uint32
	err        error
}

// NewCompositeBinaryBuilder returns a builder that appends to buf. Four zero bytes are reserved
// immediately for the field count, which Finish fills in.
func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder {
	builder := &CompositeBinaryBuilder{m: m, startIdx: len(buf)}
	builder.buf = append(buf, 0, 0, 0, 0)
	return builder
}

// AppendValue encodes field with the binary encode plan for oid and appends it. A nil field, or a
// value whose plan encodes it as NULL, is written with length -1. After the first error every further
// call is a no-op and Finish reports the error.
func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) {
	if b.err != nil {
		return
	}

	if field == nil {
		b.buf = pgio.AppendInt32(pgio.AppendUint32(b.buf, oid), -1)
		b.fieldCount++
		return
	}

	plan := b.m.PlanEncode(oid, BinaryFormatCode, field)
	if plan == nil {
		b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
		return
	}

	// Write the OID and a NULL length placeholder; the real length is patched in once known.
	b.buf = pgio.AppendUint32(b.buf, oid)
	lengthPos := len(b.buf)
	b.buf = pgio.AppendInt32(b.buf, -1)
	headerEnd := len(b.buf)

	encoded, err := plan.Encode(field, b.buf)
	if err != nil {
		b.err = err
		return
	}

	// A nil result means the plan encoded the value as NULL, so the -1 placeholder stays.
	if encoded != nil {
		binary.BigEndian.PutUint32(encoded[lengthPos:], uint32(len(encoded)-headerEnd))
		b.buf = encoded
	}
	b.fieldCount++
}

// Finish writes the field count into the reserved header and returns the buffer, or the first error
// recorded by AppendValue.
func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
	if err := b.err; err != nil {
		return nil, err
	}

	binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
	return b.buf, nil
}
// CompositeTextBuilder appends a text-format composite value, e.g. `(1,"a b",)`, to a byte slice.
// Create one with NewCompositeTextBuilder, call AppendValue once per field in order, then call Finish.
type CompositeTextBuilder struct {
	m          *Map
	buf        []byte
	startIdx   int    // offset of the opening '(' in buf
	fieldCount uint32 // fields appended so far; Finish uses it to close an empty composite correctly
	err        error
	fieldBuf   [32]byte // scratch space reused when encoding each field
}

// NewCompositeTextBuilder returns a builder that appends a composite value to buf. The opening '(' is
// written immediately.
func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder {
	startIdx := len(buf)
	buf = append(buf, '(')
	return &CompositeTextBuilder{m: m, buf: buf, startIdx: startIdx}
}

// AppendValue encodes field with the text encode plan for oid and appends it followed by ','. A nil
// field, or a value whose plan encodes it as NULL, becomes an empty (NULL) field. Non-NULL values are
// passed through quoteCompositeFieldIfNeeded. After the first error every further call is a no-op and
// Finish reports the error.
func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) {
	if b.err != nil {
		return
	}

	if field == nil {
		b.buf = append(b.buf, ',')
		b.fieldCount++
		return
	}

	plan := b.m.PlanEncode(oid, TextFormatCode, field)
	if plan == nil {
		b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
		return
	}

	fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
	if err != nil {
		b.err = err
		return
	}
	// A nil result means the plan encoded the value as NULL: leave the field empty.
	if fieldBuf != nil {
		b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
	}

	b.buf = append(b.buf, ',')
	b.fieldCount++
}

// Finish closes the composite with ')' and returns the buffer, or the first error recorded by
// AppendValue.
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
	if b.err != nil {
		return nil, b.err
	}

	if b.fieldCount == 0 {
		// Nothing follows '(' so there is no trailing ',' to overwrite; overwriting would yield ")".
		b.buf = append(b.buf, ')')
		return b.buf, nil
	}

	// Replace the ',' written after the last field.
	b.buf[len(b.buf)-1] = ')'
	return b.buf, nil
}
// quoteCompositeReplacer backslash-escapes the two characters that are special inside a double-quoted
// composite field: backslash and double quote.
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteCompositeField wraps src in double quotes, backslash-escaping every `\` and `"` it contains.
func quoteCompositeField(src string) string {
	var sb strings.Builder
	sb.Grow(len(src) + 2)
	sb.WriteByte('"')
	// strings.Builder writes never fail, so the count and error are ignored.
	_, _ = quoteCompositeReplacer.WriteString(&sb, src)
	sb.WriteByte('"')
	return sb.String()
}

// quoteCompositeFieldIfNeeded quotes src when it would otherwise be ambiguous inside a text
// composite. That is the case when it is empty (which would read as NULL), starts or ends with a
// space, or contains a parenthesis, comma, double quote or backslash.
func quoteCompositeFieldIfNeeded(src string) string {
	switch {
	case src == "",
		strings.HasPrefix(src, " "),
		strings.HasSuffix(src, " "),
		strings.ContainsAny(src, `(),"\`):
		return quoteCompositeField(src)
	default:
		return src
	}
}
// CompositeFields holds the values of a composite value, one element per field, in order. It satisfies
// CompositeIndexGetter, so it can be an encoding source, and CompositeIndexScanner, so it can be a
// scan target. It cannot scan a NULL composite, but individual fields may be NULL.
type CompositeFields []any

// SkipUnderlyingTypePlan is a marker method with no behavior.
// NOTE(review): presumably it tells Map planning not to fall back to the underlying []any type —
// confirm against the SkipUnderlyingTypePlanner interface.
func (cf CompositeFields) SkipUnderlyingTypePlan() {}

// IsNull reports whether cf is a nil slice, which encodes as a NULL composite.
func (cf CompositeFields) IsNull() bool { return cf == nil }

// Index returns the value of field i; it panics if i is out of range.
func (cf CompositeFields) Index(i int) any { return cf[i] }

// ScanNull always fails, because CompositeFields cannot represent a NULL composite.
func (cf CompositeFields) ScanNull() error {
	return errors.New("cannot scan NULL into CompositeFields")
}

// ScanIndex returns element i itself as the scan target for field i; it panics if i is out of range.
func (cf CompositeFields) ScanIndex(i int) any { return cf[i] }