package pgtype

import (
	"database/sql/driver"
	"encoding/binary"
	"fmt"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5/internal/pgio"
)

// Fixed-length unit conversions expressed in microseconds. Note that a month
// is defined here as exactly 30 days.
const (
	microsecondsPerSecond = 1000000
	microsecondsPerMinute = 60 * microsecondsPerSecond
	microsecondsPerHour   = 60 * microsecondsPerMinute
	microsecondsPerDay    = 24 * microsecondsPerHour
	microsecondsPerMonth  = 30 * microsecondsPerDay
)

// IntervalScanner is implemented by types that can be populated from an
// Interval value.
type IntervalScanner interface {
	ScanInterval(v Interval) error
}

// IntervalValuer is implemented by types that can be converted to an Interval
// for encoding.
type IntervalValuer interface {
	IntervalValue() (Interval, error)
}

// Interval holds an interval as three independent components (months, days
// and microseconds). Valid is false when the value represents SQL NULL.
type Interval struct {
	Microseconds int64
	Days         int32
	Months       int32
	Valid        bool
}

// ScanInterval implements IntervalScanner by copying v into interval.
func (interval *Interval) ScanInterval(v Interval) error {
	*interval = v
	return nil
}

// IntervalValue implements IntervalValuer by returning interval unchanged.
func (interval Interval) IntervalValue() (Interval, error) {
	return interval, nil
}

// Scan implements the database/sql Scanner interface.
//
// A nil src resets interval to the zero (invalid / NULL) value. Textual
// interval representations are accepted as either string or []byte; any other
// source type is rejected with an error.
func (interval *Interval) Scan(src any) error {
	if src == nil {
		*interval = Interval{}
		return nil
	}

	switch src := src.(type) {
	case string:
		return scanPlanTextAnyToIntervalScanner{}.Scan([]byte(src), interval)
	case []byte:
		// database/sql may hand text columns over as []byte; the text scan
		// plan already works on bytes, so pass them straight through.
		return scanPlanTextAnyToIntervalScanner{}.Scan(src, interval)
	}

	return fmt.Errorf("cannot scan %T", src)
}

// Value implements the database/sql/driver Valuer interface.
func (interval Interval) Value() (driver.Value, error) { if !interval.Valid { return nil, nil } buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) if err != nil { return nil, err } return string(buf), err } type IntervalCodec struct{} func (IntervalCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } func (IntervalCodec) PreferredFormat() int16 { return BinaryFormatCode } func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(IntervalValuer); !ok { return nil } switch format { case BinaryFormatCode: return encodePlanIntervalCodecBinary{} case TextFormatCode: return encodePlanIntervalCodecText{} } return nil } type encodePlanIntervalCodecBinary struct{} func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { interval, err := value.(IntervalValuer).IntervalValue() if err != nil { return nil, err } if !interval.Valid { return nil, nil } buf = pgio.AppendInt64(buf, interval.Microseconds) buf = pgio.AppendInt32(buf, interval.Days) buf = pgio.AppendInt32(buf, interval.Months) return buf, nil } type encodePlanIntervalCodecText struct{} func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { interval, err := value.(IntervalValuer).IntervalValue() if err != nil { return nil, err } if !interval.Valid { return nil, nil } if interval.Months != 0 { buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) buf = append(buf, " mon "...) } if interval.Days != 0 { buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) buf = append(buf, " day "...) 
} absMicroseconds := interval.Microseconds if absMicroseconds < 0 { absMicroseconds = -absMicroseconds buf = append(buf, '-') } hours := absMicroseconds / microsecondsPerHour minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond microseconds := absMicroseconds % microsecondsPerSecond timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) buf = append(buf, timeStr...) return buf, nil } func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case IntervalScanner: return scanPlanBinaryIntervalToIntervalScanner{} } case TextFormatCode: switch target.(type) { case IntervalScanner: return scanPlanTextAnyToIntervalScanner{} } } return nil } type scanPlanBinaryIntervalToIntervalScanner struct{} func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { scanner := (dst).(IntervalScanner) if src == nil { return scanner.ScanInterval(Interval{}) } if len(src) != 16 { return fmt.Errorf("Received an invalid size for an interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) days := int32(binary.BigEndian.Uint32(src[8:])) months := int32(binary.BigEndian.Uint32(src[12:])) return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) } type scanPlanTextAnyToIntervalScanner struct{} func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { scanner := (dst).(IntervalScanner) if src == nil { return scanner.ScanInterval(Interval{}) } var microseconds int64 var days int32 var months int32 parts := strings.Split(string(src), " ") for i := 0; i < len(parts)-1; i += 2 { scalar, err := strconv.ParseInt(parts[i], 10, 64) if err != nil { return fmt.Errorf("bad interval format") } switch parts[i+1] { case "year", "years": months += int32(scalar * 12) case "mon", "mons": 
months += int32(scalar) case "day", "days": days = int32(scalar) } } if len(parts)%2 == 1 { timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) if len(timeParts) != 3 { return fmt.Errorf("bad interval format") } var negative bool if timeParts[0][0] == '-' { negative = true timeParts[0] = timeParts[0][1:] } hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { return fmt.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".") seconds, err := strconv.ParseInt(sec, 10, 64) if err != nil { return fmt.Errorf("bad interval second format: %s", sec) } var uSeconds int64 if secFracFound { uSeconds, err = strconv.ParseInt(secFrac, 10, 64) if err != nil { return fmt.Errorf("bad interval decimal format: %s", secFrac) } for i := 0; i < 6-len(secFrac); i++ { uSeconds *= 10 } } microseconds = hours * microsecondsPerHour microseconds += minutes * microsecondsPerMinute microseconds += seconds * microsecondsPerSecond microseconds += uSeconds if negative { microseconds = -microseconds } } return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) } func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { return codecDecodeToTextFormat(c, m, oid, format, src) } func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } var interval Interval err := codecScan(c, m, oid, format, src, &interval) if err != nil { return nil, err } return interval, nil }