gotosocial/vendor/github.com/jackc/pgx/v5/pgtype/point.go
dependabot[bot] 9b76afc851
[chore]: Bump github.com/jackc/pgx/v5 from 5.4.3 to 5.5.0 (#2336)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-06 14:44:53 +00:00

267 lines
5.2 KiB
Go

package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Vec2 struct {
X float64
Y float64
}
type PointScanner interface {
ScanPoint(v Point) error
}
type PointValuer interface {
PointValue() (Point, error)
}
type Point struct {
P Vec2
Valid bool
}
func (p *Point) ScanPoint(v Point) error {
*p = v
return nil
}
func (p Point) PointValue() (Point, error) {
return p, nil
}
func parsePoint(src []byte) (*Point, error) {
if src == nil || bytes.Equal(src, []byte("null")) {
return &Point{}, nil
}
if len(src) < 5 {
return nil, fmt.Errorf("invalid length for point: %v", len(src))
}
if src[0] == '"' && src[len(src)-1] == '"' {
src = src[1 : len(src)-1]
}
sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
if !found {
return nil, fmt.Errorf("invalid format for point")
}
x, err := strconv.ParseFloat(sx, 64)
if err != nil {
return nil, err
}
y, err := strconv.ParseFloat(sy, 64)
if err != nil {
return nil, err
}
return &Point{P: Vec2{x, y}, Valid: true}, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Point) Scan(src any) error {
if src == nil {
*dst = Point{}
return nil
}
switch src := src.(type) {
case string:
return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Point) Value() (driver.Value, error) {
if !src.Valid {
return nil, nil
}
buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil)
if err != nil {
return nil, err
}
return string(buf), err
}
func (src Point) MarshalJSON() ([]byte, error) {
if !src.Valid {
return []byte("null"), nil
}
var buff bytes.Buffer
buff.WriteByte('"')
buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y))
buff.WriteByte('"')
return buff.Bytes(), nil
}
func (dst *Point) UnmarshalJSON(point []byte) error {
p, err := parsePoint(point)
if err != nil {
return err
}
*dst = *p
return nil
}
type PointCodec struct{}
func (PointCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (PointCodec) PreferredFormat() int16 {
return BinaryFormatCode
}
func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(PointValuer); !ok {
return nil
}
switch format {
case BinaryFormatCode:
return encodePlanPointCodecBinary{}
case TextFormatCode:
return encodePlanPointCodecText{}
}
return nil
}
type encodePlanPointCodecBinary struct{}
func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
point, err := value.(PointValuer).PointValue()
if err != nil {
return nil, err
}
if !point.Valid {
return nil, nil
}
buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X))
buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y))
return buf, nil
}
type encodePlanPointCodecText struct{}
func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
point, err := value.(PointValuer).PointValue()
if err != nil {
return nil, err
}
if !point.Valid {
return nil, nil
}
return append(buf, fmt.Sprintf(`(%s,%s)`,
strconv.FormatFloat(point.P.X, 'f', -1, 64),
strconv.FormatFloat(point.P.Y, 'f', -1, 64),
)...), nil
}
func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case PointScanner:
return scanPlanBinaryPointToPointScanner{}
}
case TextFormatCode:
switch target.(type) {
case PointScanner:
return scanPlanTextAnyToPointScanner{}
}
}
return nil
}
func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return codecDecodeToTextFormat(c, m, oid, format, src)
}
func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
var point Point
err := codecScan(c, m, oid, format, src, &point)
if err != nil {
return nil, err
}
return point, nil
}
type scanPlanBinaryPointToPointScanner struct{}
func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error {
scanner := (dst).(PointScanner)
if src == nil {
return scanner.ScanPoint(Point{})
}
if len(src) != 16 {
return fmt.Errorf("invalid length for point: %v", len(src))
}
x := binary.BigEndian.Uint64(src)
y := binary.BigEndian.Uint64(src[8:])
return scanner.ScanPoint(Point{
P: Vec2{math.Float64frombits(x), math.Float64frombits(y)},
Valid: true,
})
}
type scanPlanTextAnyToPointScanner struct{}
func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error {
scanner := (dst).(PointScanner)
if src == nil {
return scanner.ScanPoint(Point{})
}
if len(src) < 5 {
return fmt.Errorf("invalid length for point: %v", len(src))
}
sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
if !found {
return fmt.Errorf("invalid format for point")
}
x, err := strconv.ParseFloat(sx, 64)
if err != nil {
return err
}
y, err := strconv.ParseFloat(sy, 64)
if err != nil {
return err
}
return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true})
}