gotosocial/vendor/github.com/jackc/pgx/v4/extended_query_builder.go
tobi 2dc9fc1626
Pg to bun (#148)
* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
2021-08-25 15:34:33 +02:00

169 lines
4 KiB
Go

package pgx
import (
"database/sql/driver"
"fmt"
"reflect"
"github.com/jackc/pgtype"
)
// extendedQueryBuilder accumulates the pieces needed to issue one extended
// query: the encoded parameter values, the format code chosen for each
// parameter, and the requested result format codes. Reset empties it so the
// same allocations can be used again.
type extendedQueryBuilder struct {
	// paramValues holds one encoded value per appended parameter; a nil entry
	// is stored when the argument encoded to nothing (e.g. a nil argument).
	paramValues [][]byte
	// paramValueBytes is scratch storage shared by all parameters; most
	// entries of paramValues are sub-slices of it.
	paramValueBytes []byte
	// paramFormats and resultFormats hold format codes (text or binary) for
	// the parameters and for the results, respectively.
	paramFormats, resultFormats []int16
	// resetCount counts Reset calls and drives the periodic buffer shrink.
	resetCount int
}
// AppendParam encodes arg as a parameter of type oid and appends both the
// chosen format code and the encoded bytes to the builder.
//
// The format code is appended before encoding is attempted. If encoding
// fails, the error is returned and paramFormats ends up one entry longer
// than paramValues.
func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
	format := chooseParameterFormatCode(ci, oid, arg)
	eqb.paramFormats = append(eqb.paramFormats, format)

	encoded, err := eqb.encodeExtendedParamValue(ci, oid, format, arg)
	if err == nil {
		eqb.paramValues = append(eqb.paramValues, encoded)
	}
	return err
}
// AppendResultFormat appends format to the list of requested result format
// codes.
func (eqb *extendedQueryBuilder) AppendResultFormat(format int16) {
	eqb.resultFormats = append(eqb.resultFormats, format)
}
// Reset empties the builder for reuse while keeping the allocated capacity of
// its slices.
//
// Every 128th call it also halves the capacity of any slice that has grown
// beyond its threshold (64 entries, or 256 bytes for paramValueBytes). This
// keeps a single unusually large query from pinning that memory indefinitely.
func (eqb *extendedQueryBuilder) Reset() {
	// Truncating paramValues alone would leave its old entries reachable
	// through the backing array. Those entries are sub-slices of earlier
	// paramValueBytes buffers (or []byte copies of string arguments). They
	// would keep that memory alive even after paramValueBytes is shrunk
	// below, so drop the references before truncating.
	for i := range eqb.paramValues {
		eqb.paramValues[i] = nil
	}
	eqb.paramValues = eqb.paramValues[:0]
	eqb.paramValueBytes = eqb.paramValueBytes[:0]
	eqb.paramFormats = eqb.paramFormats[:0]
	eqb.resultFormats = eqb.resultFormats[:0]

	eqb.resetCount++

	// Every so often shrink our reserved memory if it is abnormally high.
	if eqb.resetCount%128 == 0 {
		if cap(eqb.paramValues) > 64 {
			eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2)
		}
		if cap(eqb.paramValueBytes) > 256 {
			eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2)
		}
		if cap(eqb.paramFormats) > 64 {
			eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2)
		}
		if cap(eqb.resultFormats) > 64 {
			eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2)
		}
	}
}
// encodeExtendedParamValue encodes arg as a parameter of type oid. formatCode
// selects whether a TextEncoder or a BinaryEncoder implementation of arg is used.
//
// It returns a nil slice (and nil error) when arg is nil, is a nil pointer, or
// when an encoder returns a nil buffer. NOTE(review): presumably the caller
// transmits a nil value as SQL NULL; confirm.
//
// A string arg is returned as a fresh []byte copy. Every other successful
// result is a sub-slice of eqb.paramValueBytes, so its bytes are overwritten
// once that buffer is truncated (see Reset) and appended to again.
//
// The cases are tried in order:
//  1. nil / nil pointer                      -> nil
//  2. string                                 -> raw bytes, whatever formatCode is
//  3. TextEncoder/BinaryEncoder that matches -> encoded into paramValueBytes
//     formatCode
//  4. non-nil pointer                        -> dereference and retry
//  5. data type registered for oid           -> Set arg into it and retry with
//                                               that value; if Set fails and
//                                               arg is a driver.Valuer, retry
//                                               with the Valuer's result
//  6. data type registered for arg's type    -> text-encode through it
//  7. stripNamedType succeeds                -> retry with the stripped value
//  8. otherwise                              -> SerializationError
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
	if arg == nil {
		return nil, nil
	}
	refVal := reflect.ValueOf(arg)
	argIsPtr := refVal.Kind() == reflect.Ptr
	// A typed nil pointer inside a non-nil interface is treated like nil.
	if argIsPtr && refVal.IsNil() {
		return nil, nil
	}
	// Lazily allocate the scratch buffer that the encoders append into.
	if eqb.paramValueBytes == nil {
		eqb.paramValueBytes = make([]byte, 0, 128)
	}
	var err error
	var buf []byte
	// Offset where this parameter's bytes will start. The encoders append to
	// paramValueBytes, so the encoded value is the tail from pos onward.
	pos := len(eqb.paramValueBytes)
	// Strings bypass the encoders entirely and are copied as-is.
	if arg, ok := arg.(string); ok {
		return []byte(arg), nil
	}
	if formatCode == TextFormatCode {
		if arg, ok := arg.(pgtype.TextEncoder); ok {
			buf, err = arg.EncodeText(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			// A nil buffer from the encoder yields the same nil result as a nil arg.
			if buf == nil {
				return nil, nil
			}
			// Keep the (possibly reallocated) buffer so later parameters append
			// after this one.
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	} else if formatCode == BinaryFormatCode {
		if arg, ok := arg.(pgtype.BinaryEncoder); ok {
			buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			// A nil buffer from the encoder yields the same nil result as a nil arg.
			if buf == nil {
				return nil, nil
			}
			// Keep the (possibly reallocated) buffer so later parameters append
			// after this one.
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	}
	if argIsPtr {
		// We have already checked that arg is not pointing to nil,
		// so it is safe to dereference here.
		arg = refVal.Elem().Interface()
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg)
	}
	if dt, ok := ci.DataTypeForOID(oid); ok {
		// NOTE(review): value is dt.Value itself, not a fresh copy. If it is a
		// pointer type, Set mutates the instance registered in ci, so ci must
		// not be used concurrently here. Confirm this is acceptable.
		value := dt.Value
		// This err shadows the function-level err declared above.
		err := value.Set(arg)
		if err != nil {
			{
				// Set could not take arg directly. If arg can produce a driver
				// value (presumably via its Value method inside
				// callValuerValue), encode that instead.
				if arg, ok := arg.(driver.Valuer); ok {
					v, err := callValuerValue(arg)
					if err != nil {
						return nil, err
					}
					return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
				}
			}
			return nil, err
		}
		// Retry with the populated pgtype value, which should hit one of the
		// encoder cases above.
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
	}
	// There is no data type registered for the destination OID, but maybe there
	// is a data type registered for the arg type. If so, use its text encoder
	// (if available).
	// NOTE(review): this encodes as text even when formatCode is binary.
	// Presumably chooseParameterFormatCode only picks binary for OIDs that have
	// a registered type or for BinaryEncoder args (handled above). Confirm.
	if dt, ok := ci.DataTypeForValue(arg); ok {
		value := dt.Value
		if textEncoder, ok := value.(pgtype.TextEncoder); ok {
			err := value.Set(arg)
			if err != nil {
				return nil, err
			}
			buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			if buf == nil {
				return nil, nil
			}
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	}
	// Last resort: if stripNamedType can reduce arg (presumably a named type
	// such as `type MyInt int`) to a more basic value, retry with that value.
	if strippedArg, ok := stripNamedType(&refVal); ok {
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
	}
	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}