package meddler import ( "bytes" "compress/gzip" "encoding/gob" "encoding/json" "fmt" "reflect" "time" ) // Meddler is the interface for a field meddler. Implementations can be // registered to convert struct fields being loaded and saved in the database. type Meddler interface { // PreRead is called before a Scan operation. It is given a pointer to // the raw struct field, and returns the value that will be given to // the database driver. PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) // PostRead is called after a Scan operation. It is given the value returned // by PreRead and a pointer to the raw struct field. It is expected to fill // in the struct field if the two are different. PostRead(fieldAddr interface{}, scanTarget interface{}) error // PreWrite is called before an Insert or Update operation. It is given // a pointer to the raw struct field, and returns the value that will be // given to the database driver. PreWrite(field interface{}) (saveValue interface{}, err error) } // Register sets up a meddler type. Meddlers get a chance to meddle with the // data being loaded or saved when a field is annotated with the name of the meddler. // The registry is global. 
func Register(name string, m Meddler) { if name == "pk" { panic("meddler.Register: pk cannot be used as a meddler name") } registry[name] = m } var registry = make(map[string]Meddler) func init() { Register("identity", IdentityMeddler(false)) Register("localtime", TimeMeddler{ZeroIsNull: false, Local: true}) Register("localtimez", TimeMeddler{ZeroIsNull: true, Local: true}) Register("utctime", TimeMeddler{ZeroIsNull: false, Local: false}) Register("utctimez", TimeMeddler{ZeroIsNull: true, Local: false}) Register("zeroisnull", ZeroIsNullMeddler(false)) Register("json", JSONMeddler(false)) Register("jsongzip", JSONMeddler(true)) Register("gob", GobMeddler(false)) Register("gobgzip", GobMeddler(true)) } // IdentityMeddler is the default meddler, and it passes the original value through with // no changes. type IdentityMeddler bool // PreRead is called before a Scan operation for fields that have the IdentityMeddler func (elt IdentityMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { return fieldAddr, nil } // PostRead is called after a Scan operation for fields that have the IdentityMeddler func (elt IdentityMeddler) PostRead(fieldAddr, scanTarget interface{}) error { return nil } // PreWrite is called before an Insert or Update operation for fields that have the IdentityMeddler func (elt IdentityMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { return field, nil } // TimeMeddler provides useful operations on time.Time fields. It can convert the zero time // to and from a null column, and it can convert the time zone to UTC on save and to Local on load. 
type TimeMeddler struct {
	// ZeroIsNull saves the zero time as NULL and loads NULL back as the zero
	// time. Reading into a *time.Time field with ZeroIsNull set is rejected.
	ZeroIsNull bool

	// Local converts loaded times to the local time zone; when false they are
	// converted to UTC. Saved times are always converted to UTC.
	Local bool
}

// PreRead is called before a Scan operation for fields that have a TimeMeddler.
// It returns the value to scan into in place of the field.
func (elt TimeMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	switch tgt := fieldAddr.(type) {
	case *time.Time:
		if elt.ZeroIsNull {
			// Scan through an extra level of indirection so the driver can
			// store nil for a NULL column without touching the field;
			// PostRead copies the result into the field.
			return &tgt, nil
		}
		return fieldAddr, nil
	case **time.Time:
		if elt.ZeroIsNull {
			return nil, fmt.Errorf("meddler.TimeMeddler cannot be used on a *time.Time field, only time.Time")
		}
		return fieldAddr, nil
	default:
		return nil, fmt.Errorf("meddler.TimeMeddler.PreRead: unknown struct field type: %T", fieldAddr)
	}
}

// PostRead is called after a Scan operation for fields that have a TimeMeddler.
// It stores the scanned time in the field, converted to Local or UTC. With
// ZeroIsNull, a NULL column becomes the zero time. If scanTarget is not the
// kind of value PreRead produces, an error is returned instead of panicking.
func (elt TimeMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
	badTarget := func() error {
		return fmt.Errorf("meddler.TimeMeddler.PostRead: unexpected scan target type: %T", scanTarget)
	}

	switch tgt := fieldAddr.(type) {
	case *time.Time:
		if elt.ZeroIsNull {
			src, ok := scanTarget.(**time.Time)
			if !ok || src == nil {
				return badTarget()
			}
			switch {
			case *src == nil:
				*tgt = time.Time{}
			case elt.Local:
				*tgt = (*src).Local()
			default:
				*tgt = (*src).UTC()
			}
			return nil
		}
		src, ok := scanTarget.(*time.Time)
		if !ok || src == nil {
			return badTarget()
		}
		if elt.Local {
			*tgt = src.Local()
		} else {
			*tgt = src.UTC()
		}
		return nil
	case **time.Time:
		if elt.ZeroIsNull {
			return fmt.Errorf("meddler.TimeMeddler cannot be used on a *time.Time field, only time.Time")
		}
		src, ok := scanTarget.(**time.Time)
		if !ok || src == nil {
			return badTarget()
		}
		switch {
		case *src == nil:
			*tgt = nil
		case elt.Local:
			// PreRead returned fieldAddr here, so src normally is the field;
			// the pointed-to time is converted in place.
			**src = (*src).Local()
			*tgt = *src
		default:
			**src = (*src).UTC()
			*tgt = *src
		}
		return nil
	default:
		return fmt.Errorf("meddler.TimeMeddler.PostRead: unknown struct field type: %T", fieldAddr)
	}
}

// PreWrite is called before an Insert or Update operation for fields that have
// a TimeMeddler. Times are saved in UTC. A nil *time.Time, or the zero time
// when ZeroIsNull is set, is saved as NULL (nil).
func (elt TimeMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
	switch tgt := field.(type) {
	case time.Time:
		if elt.ZeroIsNull && tgt.IsZero() {
			return nil, nil
		}
		return tgt.UTC(), nil
	case *time.Time:
		if tgt == nil || (elt.ZeroIsNull && tgt.IsZero()) {
			return nil, nil
		}
		return tgt.UTC(), nil
	default:
		return nil, fmt.Errorf("meddler.TimeMeddler.PreWrite: unknown struct field type: %T", field)
	}
}

// ZeroIsNullMeddler converts zero value fields (integers both signed and
// unsigned, floats, complex numbers, strings, and booleans) to and from NULL
// database columns.
type ZeroIsNullMeddler bool

// PreRead is called before a Scan operation for fields that have the
// ZeroIsNullMeddler. It returns a fresh pointer-to-pointer of the field's
// type rather than the field itself.
func (elt ZeroIsNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	// fieldAddr is a *T, so this allocates a **T holding a nil *T.
	// The database driver will set it to nil if the column value is null.
	return reflect.New(reflect.TypeOf(fieldAddr)).Interface(), nil
}

// PostRead is called after a Scan operation for fields that have the
// ZeroIsNullMeddler. A NULL column sets the field to its zero value;
// otherwise the scanned value is copied into the field.
func (elt ZeroIsNullMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
	sv := reflect.ValueOf(scanTarget)
	fv := reflect.ValueOf(fieldAddr)
	if sv.Elem().IsNil() {
		// null column, so set target to be zero value
		fv.Elem().Set(reflect.Zero(fv.Elem().Type()))
	} else {
		// copy the value that scan found
		fv.Elem().Set(sv.Elem().Elem())
	}
	return nil
}

// PreWrite is called before an Insert or Update operation for fields that have
// the ZeroIsNullMeddler. Zero values are saved as NULL (nil); other values are
// saved unchanged. Unsupported kinds produce an error.
func (elt ZeroIsNullMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
	val := reflect.ValueOf(field)
	switch val.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if val.Int() == 0 {
			return nil, nil
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if val.Uint() == 0 {
			return nil, nil
		}
	case reflect.Float32, reflect.Float64:
		if val.Float() == 0 {
			return nil, nil
		}
	case reflect.Complex64, reflect.Complex128:
		if val.Complex() == 0 {
			return nil, nil
		}
	case reflect.String:
		if val.String() == "" {
			return nil, nil
		}
	case reflect.Bool:
		if !val.Bool() {
			return nil, nil
		}
	default:
		return nil, fmt.Errorf("ZeroIsNullMeddler.PreWrite: unknown struct field type: %T", field)
	}
	return field, nil
}

// JSONMeddler encodes or decodes the field value to or from JSON. When the
// meddler value is true, the JSON is additionally gzip-compressed.
type JSONMeddler bool

// PreRead is called before a Scan operation for fields that have the
// JSONMeddler. The raw column bytes are scanned into a new byte slice.
func (zip JSONMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	// give a pointer to a byte buffer to grab the raw data
	return new([]byte), nil
}

// PostRead is called after a Scan operation for fields that have the
// JSONMeddler. It decodes the raw bytes (un-gzipping them first if zip is set)
// into the field. With gzip, the stream's checksum is verified before decoding.
//
// NOTE(review): a NULL column presumably leaves the byte slice empty, which
// surfaces as an EOF error here — confirm that is the intended behaviour.
func (zip JSONMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
	ptr, ok := scanTarget.(*[]byte)
	if !ok {
		return fmt.Errorf("JSONMeddler.PostRead: unexpected scan target type: %T", scanTarget)
	}
	if ptr == nil {
		return fmt.Errorf("JSONMeddler.PostRead: nil pointer")
	}
	raw := *ptr

	if zip {
		// un-gzip and decode json
		gzipReader, err := gzip.NewReader(bytes.NewReader(raw))
		if err != nil {
			return fmt.Errorf("Error creating gzip Reader: %v", err)
		}
		defer gzipReader.Close()

		// Decompress everything first: gzip verifies its CRC-32/size trailer
		// only when the stream is read through to io.EOF, and the JSON decoder
		// stops reading as soon as it has a complete value.
		var plain bytes.Buffer
		if _, err := plain.ReadFrom(gzipReader); err != nil {
			return fmt.Errorf("JSON decoder/gzip error: %v", err)
		}
		if err := json.NewDecoder(&plain).Decode(fieldAddr); err != nil {
			return fmt.Errorf("JSON decoder/gzip error: %v", err)
		}
		if err := gzipReader.Close(); err != nil {
			return fmt.Errorf("Closing gzip reader: %v", err)
		}
		return nil
	}

	// decode json
	jsonDecoder := json.NewDecoder(bytes.NewReader(raw))
	if err := jsonDecoder.Decode(fieldAddr); err != nil {
		return fmt.Errorf("JSON decode error: %v", err)
	}
	return nil
}

// PreWrite is called before an Insert or Update operation for fields that have
// the JSONMeddler. It returns the JSON encoding of field, gzip-compressed if
// zip is set.
func (zip JSONMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
	buffer := new(bytes.Buffer)

	if zip {
		// json encode and gzip
		gzipWriter := gzip.NewWriter(buffer)
		defer gzipWriter.Close()
		jsonEncoder := json.NewEncoder(gzipWriter)
		if err := jsonEncoder.Encode(field); err != nil {
			return nil, fmt.Errorf("JSON encoding/gzip error: %v", err)
		}
		// Close explicitly so the gzip trailer is in buffer before returning.
		if err := gzipWriter.Close(); err != nil {
			return nil, fmt.Errorf("Closing gzip writer: %v", err)
		}
		return buffer.Bytes(), nil
	}

	// json encode
	jsonEncoder := json.NewEncoder(buffer)
	if err := jsonEncoder.Encode(field); err != nil {
		return nil, fmt.Errorf("JSON encoding error: %v", err)
	}
	return buffer.Bytes(), nil
}

// GobMeddler encodes or decodes the field value to or from gob. When the
// meddler value is true, the gob data is additionally gzip-compressed.
type GobMeddler bool

// PreRead is called before a Scan operation for fields that have the
// GobMeddler. The raw column bytes are scanned into a new byte slice.
func (zip GobMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	// give a pointer to a byte buffer to grab the raw data
	return new([]byte), nil
}

// PostRead is called after a Scan operation for fields that have the
// GobMeddler. It decodes the raw bytes (un-gzipping them first if zip is set)
// into the field. With gzip, the stream's checksum is verified before decoding.
//
// NOTE(review): a NULL column presumably leaves the byte slice empty, which
// surfaces as an EOF error here — confirm that is the intended behaviour.
func (zip GobMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
	ptr, ok := scanTarget.(*[]byte)
	if !ok {
		return fmt.Errorf("GobMeddler.PostRead: unexpected scan target type: %T", scanTarget)
	}
	if ptr == nil {
		return fmt.Errorf("GobMeddler.PostRead: nil pointer")
	}
	raw := *ptr

	if zip {
		// un-gzip and decode gob
		gzipReader, err := gzip.NewReader(bytes.NewReader(raw))
		if err != nil {
			return fmt.Errorf("Error creating gzip Reader: %v", err)
		}
		defer gzipReader.Close()

		// Decompress everything first: gzip verifies its CRC-32/size trailer
		// only when the stream is read through to io.EOF, and the gob decoder
		// stops reading once the value has been decoded.
		var plain bytes.Buffer
		if _, err := plain.ReadFrom(gzipReader); err != nil {
			return fmt.Errorf("Gob decoder/gzip error: %v", err)
		}
		if err := gob.NewDecoder(&plain).Decode(fieldAddr); err != nil {
			return fmt.Errorf("Gob decoder/gzip error: %v", err)
		}
		if err := gzipReader.Close(); err != nil {
			return fmt.Errorf("Closing gzip reader: %v", err)
		}
		return nil
	}

	// decode gob
	gobDecoder := gob.NewDecoder(bytes.NewReader(raw))
	if err := gobDecoder.Decode(fieldAddr); err != nil {
		return fmt.Errorf("Gob decode error: %v", err)
	}
	return nil
}

// PreWrite is called before an Insert or Update operation for fields that have
// the GobMeddler. It returns the gob encoding of field, gzip-compressed if zip
// is set.
func (zip GobMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
	buffer := new(bytes.Buffer)

	if zip {
		// gob encode and gzip
		gzipWriter := gzip.NewWriter(buffer)
		defer gzipWriter.Close()
		gobEncoder := gob.NewEncoder(gzipWriter)
		if err := gobEncoder.Encode(field); err != nil {
			return nil, fmt.Errorf("Gob encoding/gzip error: %v", err)
		}
		// Close explicitly so the gzip trailer is in buffer before returning.
		if err := gzipWriter.Close(); err != nil {
			return nil, fmt.Errorf("Closing gzip writer: %v", err)
		}
		return buffer.Bytes(), nil
	}

	// gob encode
	gobEncoder := gob.NewEncoder(buffer)
	if err := gobEncoder.Encode(field); err != nil {
		return nil, fmt.Errorf("Gob encoding error: %v", err)
	}
	return buffer.Bytes(), nil
}