forked from mirrors/gotosocial
bdcc090851
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
149 lines
2.5 KiB
Go
149 lines
2.5 KiB
Go
package schema
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
)
|
|
|
|
type tableInProgress struct {
|
|
table *Table
|
|
|
|
init1Once sync.Once
|
|
init2Once sync.Once
|
|
}
|
|
|
|
func newTableInProgress(table *Table) *tableInProgress {
|
|
return &tableInProgress{
|
|
table: table,
|
|
}
|
|
}
|
|
|
|
func (inp *tableInProgress) init1() bool {
|
|
var inited bool
|
|
inp.init1Once.Do(func() {
|
|
inp.table.init1()
|
|
inited = true
|
|
})
|
|
return inited
|
|
}
|
|
|
|
func (inp *tableInProgress) init2() bool {
|
|
var inited bool
|
|
inp.init2Once.Do(func() {
|
|
inp.table.init2()
|
|
inited = true
|
|
})
|
|
return inited
|
|
}
|
|
|
|
type Tables struct {
|
|
dialect Dialect
|
|
tables sync.Map
|
|
|
|
mu sync.RWMutex
|
|
inProgress map[reflect.Type]*tableInProgress
|
|
}
|
|
|
|
func NewTables(dialect Dialect) *Tables {
|
|
return &Tables{
|
|
dialect: dialect,
|
|
inProgress: make(map[reflect.Type]*tableInProgress),
|
|
}
|
|
}
|
|
|
|
func (t *Tables) Register(models ...interface{}) {
|
|
for _, model := range models {
|
|
_ = t.Get(reflect.TypeOf(model).Elem())
|
|
}
|
|
}
|
|
|
|
func (t *Tables) Get(typ reflect.Type) *Table {
|
|
return t.table(typ, false)
|
|
}
|
|
|
|
func (t *Tables) Ref(typ reflect.Type) *Table {
|
|
return t.table(typ, true)
|
|
}
|
|
|
|
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
|
|
typ = indirectType(typ)
|
|
if typ.Kind() != reflect.Struct {
|
|
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
|
|
}
|
|
|
|
if v, ok := t.tables.Load(typ); ok {
|
|
return v.(*Table)
|
|
}
|
|
|
|
t.mu.Lock()
|
|
|
|
if v, ok := t.tables.Load(typ); ok {
|
|
t.mu.Unlock()
|
|
return v.(*Table)
|
|
}
|
|
|
|
var table *Table
|
|
|
|
inProgress := t.inProgress[typ]
|
|
if inProgress == nil {
|
|
table = newTable(t.dialect, typ)
|
|
inProgress = newTableInProgress(table)
|
|
t.inProgress[typ] = inProgress
|
|
} else {
|
|
table = inProgress.table
|
|
}
|
|
|
|
t.mu.Unlock()
|
|
|
|
inProgress.init1()
|
|
if allowInProgress {
|
|
return table
|
|
}
|
|
|
|
if inProgress.init2() {
|
|
t.mu.Lock()
|
|
delete(t.inProgress, typ)
|
|
t.tables.Store(typ, table)
|
|
t.mu.Unlock()
|
|
}
|
|
|
|
t.dialect.OnTable(table)
|
|
|
|
for _, field := range table.FieldMap {
|
|
if field.UserSQLType == "" {
|
|
field.UserSQLType = field.DiscoveredSQLType
|
|
}
|
|
if field.CreateTableSQLType == "" {
|
|
field.CreateTableSQLType = field.UserSQLType
|
|
}
|
|
}
|
|
|
|
return table
|
|
}
|
|
|
|
func (t *Tables) ByModel(name string) *Table {
|
|
var found *Table
|
|
t.tables.Range(func(key, value interface{}) bool {
|
|
t := value.(*Table)
|
|
if t.TypeName == name {
|
|
found = t
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return found
|
|
}
|
|
|
|
func (t *Tables) ByName(name string) *Table {
|
|
var found *Table
|
|
t.tables.Range(func(key, value interface{}) bool {
|
|
t := value.(*Table)
|
|
if t.Name == name {
|
|
found = t
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return found
|
|
}
|