gotosocial/vendor/github.com/go-pg/pg/v10/hook.go
Tobi Smethurst 98263a7de6
Grand test fixup (#138)
* start fixing up tests

* fix up tests + automate with drone

* fiddle with linting

* messing about with drone.yml

* some more fiddling

* hmmm

* add cache

* add vendor directory

* verbose

* ci updates

* update some little things

* update sig
2021-08-12 21:03:24 +02:00

139 lines
3.2 KiB
Go

package pg
import (
"context"
"fmt"
"time"
"github.com/go-pg/pg/v10/orm"
)
type (
BeforeScanHook = orm.BeforeScanHook
AfterScanHook = orm.AfterScanHook
AfterSelectHook = orm.AfterSelectHook
BeforeInsertHook = orm.BeforeInsertHook
AfterInsertHook = orm.AfterInsertHook
BeforeUpdateHook = orm.BeforeUpdateHook
AfterUpdateHook = orm.AfterUpdateHook
BeforeDeleteHook = orm.BeforeDeleteHook
AfterDeleteHook = orm.AfterDeleteHook
)
//------------------------------------------------------------------------------
type dummyFormatter struct{}
func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte {
return append(b, query...)
}
// QueryEvent ...
type QueryEvent struct {
StartTime time.Time
DB orm.DB
Model interface{}
Query interface{}
Params []interface{}
fmtedQuery []byte
Result Result
Err error
Stash map[interface{}]interface{}
}
// QueryHook ...
type QueryHook interface {
BeforeQuery(context.Context, *QueryEvent) (context.Context, error)
AfterQuery(context.Context, *QueryEvent) error
}
// UnformattedQuery returns the unformatted query of a query event.
// The query is only valid until the query Result is returned to the user.
func (e *QueryEvent) UnformattedQuery() ([]byte, error) {
return queryString(e.Query)
}
func queryString(query interface{}) ([]byte, error) {
switch query := query.(type) {
case orm.TemplateAppender:
return query.AppendTemplate(nil)
case string:
return dummyFormatter{}.FormatQuery(nil, query), nil
default:
return nil, fmt.Errorf("pg: can't append %T", query)
}
}
// FormattedQuery returns the formatted query of a query event.
// The query is only valid until the query Result is returned to the user.
func (e *QueryEvent) FormattedQuery() ([]byte, error) {
return e.fmtedQuery, nil
}
// AddQueryHook adds a hook into query processing.
func (db *baseDB) AddQueryHook(hook QueryHook) {
db.queryHooks = append(db.queryHooks, hook)
}
func (db *baseDB) beforeQuery(
ctx context.Context,
ormDB orm.DB,
model, query interface{},
params []interface{},
fmtedQuery []byte,
) (context.Context, *QueryEvent, error) {
if len(db.queryHooks) == 0 {
return ctx, nil, nil
}
event := &QueryEvent{
StartTime: time.Now(),
DB: ormDB,
Model: model,
Query: query,
Params: params,
fmtedQuery: fmtedQuery,
}
for i, hook := range db.queryHooks {
var err error
ctx, err = hook.BeforeQuery(ctx, event)
if err != nil {
if err := db.afterQueryFromIndex(ctx, event, i); err != nil {
return ctx, nil, err
}
return ctx, nil, err
}
}
return ctx, event, nil
}
func (db *baseDB) afterQuery(
ctx context.Context,
event *QueryEvent,
res Result,
err error,
) error {
if event == nil {
return nil
}
event.Err = err
event.Result = res
return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
}
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
for ; hookIndex >= 0; hookIndex-- {
if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil {
return err
}
}
return nil
}
func copyQueryHooks(s []QueryHook) []QueryHook {
return s[:len(s):len(s)]
}