package pgtype import ( "bytes" "database/sql/driver" "encoding/binary" "fmt" "math" "strconv" "strings" "github.com/jackc/pgx/v5/internal/pgio" ) type Vec2 struct { X float64 Y float64 } type PointScanner interface { ScanPoint(v Point) error } type PointValuer interface { PointValue() (Point, error) } type Point struct { P Vec2 Valid bool } func (p *Point) ScanPoint(v Point) error { *p = v return nil } func (p Point) PointValue() (Point, error) { return p, nil } func parsePoint(src []byte) (*Point, error) { if src == nil || bytes.Equal(src, []byte("null")) { return &Point{}, nil } if len(src) < 5 { return nil, fmt.Errorf("invalid length for point: %v", len(src)) } if src[0] == '"' && src[len(src)-1] == '"' { src = src[1 : len(src)-1] } sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") if !found { return nil, fmt.Errorf("invalid format for point") } x, err := strconv.ParseFloat(sx, 64) if err != nil { return nil, err } y, err := strconv.ParseFloat(sy, 64) if err != nil { return nil, err } return &Point{P: Vec2{x, y}, Valid: true}, nil } // Scan implements the database/sql Scanner interface. func (dst *Point) Scan(src any) error { if src == nil { *dst = Point{} return nil } switch src := src.(type) { case string: return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. func (src Point) Value() (driver.Value, error) { if !src.Valid { return nil, nil } buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err } return string(buf), err } func (src Point) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil } var buff bytes.Buffer buff.WriteByte('"') buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) buff.WriteByte('"') return buff.Bytes(), nil } func (dst *Point) UnmarshalJSON(point []byte) error { p, err := parsePoint(point) if err != nil { return err } *dst = *p return nil } type PointCodec struct{} func (PointCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } func (PointCodec) PreferredFormat() int16 { return BinaryFormatCode } func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(PointValuer); !ok { return nil } switch format { case BinaryFormatCode: return encodePlanPointCodecBinary{} case TextFormatCode: return encodePlanPointCodecText{} } return nil } type encodePlanPointCodecBinary struct{} func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err } if !point.Valid { return nil, nil } buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X)) buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y)) return buf, nil } type encodePlanPointCodecText struct{} func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { point, err := value.(PointValuer).PointValue() if err != nil { return nil, err } if !point.Valid { return nil, nil } return append(buf, fmt.Sprintf(`(%s,%s)`, strconv.FormatFloat(point.P.X, 'f', -1, 64), strconv.FormatFloat(point.P.Y, 'f', -1, 64), )...), nil } func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case PointScanner: return scanPlanBinaryPointToPointScanner{} } case TextFormatCode: switch target.(type) { case PointScanner: return scanPlanTextAnyToPointScanner{} } } return nil } func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { return codecDecodeToTextFormat(c, m, oid, format, src) } func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } var point Point err := codecScan(c, m, oid, format, src, &point) if err != nil { return nil, err } return point, nil } type scanPlanBinaryPointToPointScanner struct{} func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error { scanner := (dst).(PointScanner) if src == nil { return scanner.ScanPoint(Point{}) } if len(src) != 16 { return fmt.Errorf("invalid length for point: %v", len(src)) } x := binary.BigEndian.Uint64(src) y := binary.BigEndian.Uint64(src[8:]) return scanner.ScanPoint(Point{ P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Valid: true, }) } type scanPlanTextAnyToPointScanner struct{} func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error { scanner := (dst).(PointScanner) if src == nil { return scanner.ScanPoint(Point{}) } if len(src) < 5 { return fmt.Errorf("invalid length for point: %v", len(src)) } sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") if !found { return fmt.Errorf("invalid format for point") } x, err := strconv.ParseFloat(sx, 64) if err != nil { return err } y, err := strconv.ParseFloat(sy, 64) if err != nil { return err } return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true}) }