gotosocial/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go
dependabot[bot] 3ab6214449
[chore]: Bump github.com/jackc/pgx/v5 from 5.5.0 to 5.5.1 (#2468)
Bumps [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) from 5.5.0 to 5.5.1.
- [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgx/compare/v5.5.0...v5.5.1)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgx/v5
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-03 10:03:16 +00:00

444 lines
9.6 KiB
Go

package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"reflect"
"github.com/jackc/pgx/v5/internal/pgio"
)
// MultirangeGetter is a type that can be converted into a PostgreSQL multirange.
//
// MultirangeCodec.PlanEncode only produces an encode plan for values that implement this interface.
type MultirangeGetter interface {
// IsNull returns true if the value is SQL NULL. The multirange encode plans return a nil buffer in that case.
IsNull() bool
// Len returns the number of elements in the multirange.
Len() int
// Index returns the element at i. The encode plans call it for every i from 0 up to Len()-1 and reject a nil result.
Index(i int) any
// IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode.
IndexType() any
}
// MultirangeSetter is a type that can be set from a PostgreSQL multirange.
//
// MultirangeCodec.PlanScan only produces a scan plan for targets that implement this interface.
type MultirangeSetter interface {
// ScanNull sets the value to SQL NULL.
ScanNull() error
// SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing
// elements.
SetLen(n int) error
// ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex.
ScanIndex(i int) any
// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
// MultirangeCodec.PlanScan.
ScanIndexType() any
}
// MultirangeCodec is a codec for any multirange type.
//
// Format support and the preferred format are delegated to ElementType.Codec, and every element of a multirange is
// planned for encoding and scanning with ElementType.OID.
type MultirangeCodec struct {
// ElementType describes the elements contained in the multirange. The codec's methods dereference it without a
// nil check, so it must be set before the codec is used.
ElementType *Type
}
// FormatSupported reports whether format is supported, as decided by the codec of the element type.
func (c *MultirangeCodec) FormatSupported(format int16) bool {
	elementCodec := c.ElementType.Codec
	return elementCodec.FormatSupported(format)
}
// PreferredFormat returns the format preferred by the codec of the element type.
func (c *MultirangeCodec) PreferredFormat() int16 {
	elementCodec := c.ElementType.Codec
	return elementCodec.PreferredFormat()
}
// PlanEncode returns a plan that encodes value as a multirange in the requested format.
//
// It returns nil when value does not implement MultirangeGetter, when m cannot plan an encode of the value
// reported by IndexType for the element OID, or when format is neither BinaryFormatCode nor TextFormatCode.
func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	getter, isGetter := value.(MultirangeGetter)
	if !isGetter {
		return nil
	}

	// Probe with a representative element so unsupported element types are rejected at planning time.
	sample := getter.IndexType()
	if m.PlanEncode(c.ElementType.OID, format, sample) == nil {
		return nil
	}

	if format == BinaryFormatCode {
		return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
	}
	if format == TextFormatCode {
		return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
	}
	return nil
}
// encodePlanMultirangeCodecText encodes a MultirangeGetter in the text format.
type encodePlanMultirangeCodecText struct {
	ac  *MultirangeCodec // codec whose ElementType supplies the element OID
	m   *Map             // used to plan the encoding of each element
	oid uint32           // OID passed to PlanEncode; not read by Encode
}

// Encode appends the text form of value to buf: '{', the comma-separated encoded elements, then '}'.
//
// value must implement MultirangeGetter (the type assertion panics otherwise). A NULL multirange yields
// (nil, nil). An element that is nil, or that encodes to a nil buffer, is rejected with an error, as is an
// element for which no encode plan can be found.
func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
	multirange := value.(MultirangeGetter)
	if multirange.IsNull() {
		return nil, nil
	}

	n := multirange.Len()
	buf = append(buf, '{')

	var (
		plan     EncodePlan
		planType reflect.Type
	)
	// Every element is encoded into the same empty scratch slice and then copied into buf.
	scratch := make([]byte, 0, 32)

	for i := 0; i < n; i++ {
		if i != 0 {
			buf = append(buf, ',')
		}

		elem := multirange.Index(i)
		if elem == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		}

		// Re-plan only when the dynamic element type differs from the previous element's.
		if t := reflect.TypeOf(elem); t != planType {
			planType = t
			plan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
			if plan == nil {
				return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
			}
		}

		encoded, encErr := plan.Encode(elem, scratch)
		if encErr != nil {
			return nil, encErr
		}
		if encoded == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		}
		buf = append(buf, encoded...)
	}

	return append(buf, '}'), nil
}
// encodePlanMultirangeCodecBinary encodes a MultirangeGetter in the binary format.
type encodePlanMultirangeCodecBinary struct {
	ac  *MultirangeCodec // codec whose ElementType supplies the element OID
	m   *Map             // used to plan the encoding of each element
	oid uint32           // OID passed to PlanEncode; not read by Encode
}

// Encode appends the binary form of value to buf: an int32 element count followed, for each element, by an
// int32 byte length and the encoded element.
//
// value must implement MultirangeGetter (the type assertion panics otherwise). A NULL multirange yields
// (nil, nil). An element that is nil, or that encodes to a nil buffer, is rejected with an error, as is an
// element for which no encode plan can be found.
func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
	multirange := value.(MultirangeGetter)
	if multirange.IsNull() {
		return nil, nil
	}

	n := multirange.Len()
	buf = pgio.AppendInt32(buf, int32(n))

	var (
		plan     EncodePlan
		planType reflect.Type
	)

	for i := 0; i < n; i++ {
		// Reserve the length prefix now; it is overwritten once the element's size is known.
		lenPos := len(buf)
		buf = pgio.AppendInt32(buf, -1)

		elem := multirange.Index(i)
		if elem == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		}

		// Re-plan only when the dynamic element type differs from the previous element's.
		if t := reflect.TypeOf(elem); t != planType {
			planType = t
			plan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
			if plan == nil {
				return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
			}
		}

		encoded, encErr := plan.Encode(elem, buf)
		if encErr != nil {
			return nil, encErr
		}
		if encoded == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		}

		buf = encoded
		pgio.SetInt32(buf[lenPos:], int32(len(buf)-lenPos-4))
	}

	return buf, nil
}
// PlanScan returns a plan that scans a multirange in the given format into target.
//
// It returns nil when target does not implement MultirangeSetter, or when planning a scan of the element OID
// into the value reported by ScanIndexType yields a *scanPlanFail.
func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
	setter, isSetter := target.(MultirangeSetter)
	if !isSetter {
		return nil
	}

	// Probe with a representative element target so unsupported element types are rejected at planning time.
	sample := setter.ScanIndexType()
	probe := m.PlanScan(c.ElementType.OID, format, sample)
	if _, failed := probe.(*scanPlanFail); failed {
		return nil
	}

	plan := new(scanPlanMultirangeCodec)
	plan.multirangeCodec = c
	plan.m = m
	plan.oid = oid
	plan.formatCode = format
	return plan
}
// decodeBinary decodes a binary-format multirange from src into multirange.
//
// The layout read here is a 4-byte big-endian element count followed, for each element, by a 4-byte big-endian
// signed length and that many bytes of element data. An element with a negative length is passed to the
// element scan plan as a nil source.
//
// src is validated before it is indexed: a missing header, an element count that cannot fit in the remaining
// bytes, a truncated length prefix, or an element length that runs past the end of src all return an error
// instead of panicking. The count is checked before SetLen is called so that a corrupt header cannot trigger
// an oversized allocation.
func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
	const int32Size = 4

	if len(src) < int32Size {
		return fmt.Errorf("multirange header too short: %d", len(src))
	}

	rp := 0

	elementCount := int(binary.BigEndian.Uint32(src[rp:]))
	rp += int32Size

	// Every element consumes at least its 4-byte length prefix, which bounds the plausible element count.
	// The negative check covers platforms where int is 32 bits wide.
	if elementCount < 0 || elementCount > (len(src)-rp)/int32Size {
		return fmt.Errorf("invalid multirange element count %d for %d remaining bytes", elementCount, len(src)-rp)
	}

	err := multirange.SetLen(elementCount)
	if err != nil {
		return err
	}

	if elementCount == 0 {
		return nil
	}

	// Prefer a plan from the element codec itself; fall back to the map when the codec has none.
	elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
	if elementScanPlan == nil {
		elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
	}

	for i := 0; i < elementCount; i++ {
		elem := multirange.ScanIndex(i)

		if len(src)-rp < int32Size {
			return fmt.Errorf("multirange element %d: length prefix truncated", i)
		}
		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
		rp += int32Size

		var elemSrc []byte
		if elemLen >= 0 {
			if elemLen > len(src)-rp {
				return fmt.Errorf("multirange element %d: length %d exceeds %d remaining bytes", i, elemLen, len(src)-rp)
			}
			elemSrc = src[rp : rp+elemLen]
			rp += elemLen
		}

		err = elementScanPlan.Scan(elemSrc, elem)
		if err != nil {
			return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
		}
	}

	return nil
}
// decodeText decodes a text-format multirange from src into multirange.
//
// src is split into per-range strings by parseUntypedTextMultirange, multirange is resized with SetLen, and
// each string is scanned into the matching ScanIndex target. A single element scan plan, planned against the
// first target, is used for every element. Errors from parsing, SetLen and element scans are returned as-is.
func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
	ranges, err := parseUntypedTextMultirange(src)
	if err != nil {
		return err
	}

	if err = multirange.SetLen(len(ranges)); err != nil {
		return err
	}
	if len(ranges) == 0 {
		return nil
	}

	// Prefer a plan from the element codec itself; fall back to the map when the codec has none.
	plan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
	if plan == nil {
		plan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
	}

	for i := range ranges {
		dst := multirange.ScanIndex(i)
		if err = plan.Scan([]byte(ranges[i]), dst); err != nil {
			return err
		}
	}

	return nil
}
// scanPlanMultirangeCodec scans a multirange value into a MultirangeSetter.
type scanPlanMultirangeCodec struct {
	multirangeCodec *MultirangeCodec // codec that performs the actual decoding
	m               *Map             // passed through to the decode functions
	oid             uint32           // multirange OID passed through to the decode functions
	formatCode      int16            // wire format the plan was created for
	elementScanPlan ScanPlan         // neither set nor read by the code in this file
}

// Scan decodes src into dst according to the plan's format code.
//
// dst must implement MultirangeSetter (the type assertion panics otherwise). A nil src is treated as SQL NULL
// and delegated to ScanNull. A format code other than binary or text produces an error.
func (p *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
	codec, m := p.multirangeCodec, p.m
	oid, format := p.oid, p.formatCode

	multirange := dst.(MultirangeSetter)
	if src == nil {
		return multirange.ScanNull()
	}

	if format == BinaryFormatCode {
		return codec.decodeBinary(m, oid, src, multirange)
	}
	if format == TextFormatCode {
		return codec.decodeText(m, oid, src, multirange)
	}
	return fmt.Errorf("unknown format code %d", format)
}
// DecodeDatabaseSQLValue converts src into a database/sql driver value without interpreting its contents.
//
// A nil src yields (nil, nil). Text-format data is returned as a string and binary-format data as a fresh
// []byte copy of src. Any other format code produces an error.
func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}

	if format == TextFormatCode {
		return string(src), nil
	}
	if format != BinaryFormatCode {
		return nil, fmt.Errorf("unknown format code %d", format)
	}

	// Copy so the returned slice does not alias src.
	out := make([]byte, len(src))
	copy(out, src)
	return out, nil
}
// DecodeValue decodes src into a Multirange[Range[any]].
//
// A nil src yields (nil, nil). Otherwise a scan plan for oid and format is obtained from m and run against a
// fresh Multirange[Range[any]]; that value is returned together with any scan error.
func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var decoded Multirange[Range[any]]
	plan := m.PlanScan(oid, format, &decoded)
	if err := plan.Scan(src, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// parseUntypedTextMultirange splits the text form of a multirange into the text of its individual ranges.
//
// After skipWhitespace is applied, src must start with '{'. Ranges are then read with parseRange, with ','
// between ranges skipped, until the closing '}'. Anything other than what skipWhitespace consumes after the
// closing '}' is an error. The returned slice is non-nil even when the multirange is empty. The range strings
// are returned unparsed.
func parseUntypedTextMultirange(src []byte) ([]string, error) {
	elements := make([]string, 0)

	buf := bytes.NewBuffer(src)

	skipWhitespace(buf)

	r, _, err := buf.ReadRune()
	if err != nil {
		return nil, fmt.Errorf("invalid multirange: %w", err)
	}

	if r != '{' {
		// %q prints the offending character itself; %v would print its numeric code point.
		return nil, fmt.Errorf("invalid multirange, expected '{' got %q", r)
	}

parseValueLoop:
	for {
		r, _, err = buf.ReadRune()
		if err != nil {
			return nil, fmt.Errorf("invalid multirange: %w", err)
		}

		switch r {
		case ',': // skip range separator
		case '}':
			break parseValueLoop
		default:
			// Cannot fail: it directly follows a successful ReadRune.
			_ = buf.UnreadRune()
			value, parseErr := parseRange(buf)
			if parseErr != nil {
				return nil, fmt.Errorf("invalid multirange value: %w", parseErr)
			}
			elements = append(elements, value)
		}
	}

	skipWhitespace(buf)

	if buf.Len() > 0 {
		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
	}

	return elements, nil
}
// parseRange reads the text of a single range from buf and returns it verbatim.
//
// A range ends at the second unquoted ',' (the first one separates the lower and upper bounds) or at an
// unquoted '}', whichever comes first. The terminating rune is left unread in buf so the caller can inspect
// it. Characters inside double quotes, and any character that follows a backslash, are copied through and
// never treated as a separator or terminator, so bound values such as "a,b" or "}" do not end the range
// early. Quotes and backslashes are kept in the returned text.
//
// An error is returned when buf is exhausted before a terminator is found.
func parseRange(buf *bytes.Buffer) (string, error) {
	s := &bytes.Buffer{}

	boundSepRead := false
	inQuotes := false
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return "", err
		}

		switch {
		case r == '\\':
			// Keep the backslash and take the following rune literally.
			s.WriteRune(r)
			r, _, err = buf.ReadRune()
			if err != nil {
				return "", err
			}
		case r == '"':
			// A doubled quote inside a quoted value closes and immediately reopens the quoted section.
			inQuotes = !inQuotes
		case inQuotes:
			// Nothing inside quotes is structural.
		case r == ',' && !boundSepRead:
			boundSepRead = true
		case r == ',' || r == '}':
			// Cannot fail: it directly follows a successful ReadRune.
			_ = buf.UnreadRune()
			return s.String(), nil
		}

		s.WriteRune(r)
	}
}
// Multirange is a generic multirange type backed by a slice of ranges.
//
// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
// enforce the RangeScanner constraint.
type Multirange[T RangeValuer] []T

// IsNull reports whether the multirange is nil, which represents SQL NULL.
func (mr Multirange[T]) IsNull() bool {
	return mr == nil
}

// Len returns the number of ranges held.
func (mr Multirange[T]) Len() int {
	return len(mr)
}

// Index returns the range at position i. It panics if i is out of bounds.
func (mr Multirange[T]) Index(i int) any {
	elem := mr[i]
	return elem
}

// IndexType returns the zero value of T.
func (mr Multirange[T]) IndexType() any {
	return *new(T)
}

// ScanNull sets the multirange to nil.
func (mr *Multirange[T]) ScanNull() error {
	*mr = nil
	return nil
}

// SetLen replaces the contents with n zero-valued ranges.
func (mr *Multirange[T]) SetLen(n int) error {
	*mr = make(Multirange[T], n)
	return nil
}

// ScanIndex returns a pointer to the range at position i. It panics if i is out of bounds.
func (mr Multirange[T]) ScanIndex(i int) any {
	target := &mr[i]
	return target
}

// ScanIndexType returns a pointer to a new zero value of T.
func (mr Multirange[T]) ScanIndexType() any {
	var zero T
	return &zero
}