diff --git a/cmd/doctor.go b/cmd/doctor.go index 9957053365..045d67003d 100644 --- a/cmd/doctor.go +++ b/cmd/doctor.go @@ -130,11 +130,17 @@ func runRecreateTable(ctx *cli.Context) error { } recreateTables := migrate_base.RecreateTables(beans...) - return db.InitEngineWithMigration(stdCtx, func(x *xorm.Engine) error { - if err := migrations.EnsureUpToDate(x); err != nil { + return db.InitEngineWithMigration(stdCtx, func(x db.Engine) error { + var engine *xorm.Engine + if getter, ok := x.(interface{ Master() *xorm.Engine }); ok { + engine = getter.Master() + } else { + engine = x.(*xorm.Engine) + } + if err := migrations.EnsureUpToDate(engine); err != nil { return err } - return recreateTables(x) + return recreateTables(engine) }) } diff --git a/cmd/migrate.go b/cmd/migrate.go index 53c496a36c..ad4658ec7d 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/modules/setting" "github.com/urfave/cli/v2" + "xorm.io/xorm" ) // CmdMigrate represents the available migrate sub-command. @@ -36,7 +37,15 @@ func runMigrate(ctx *cli.Context) error { log.Info("Log path: %s", setting.Log.RootPath) log.Info("Configuration file: %s", setting.CustomConf) - if err := db.InitEngineWithMigration(context.Background(), migrations.Migrate); err != nil { + if err := db.InitEngineWithMigration(context.Background(), func(engine db.Engine) error { + var e *xorm.Engine + if getter, ok := engine.(interface{ Master() *xorm.Engine }); ok { + e = getter.Master() + } else { + e = engine.(*xorm.Engine) + } + return migrations.Migrate(e) + }); err != nil { log.Fatal("Failed to initialize ORM engine: %v", err) return err } diff --git a/models/db/engine.go b/models/db/engine.go index 822618a7e3..b45cae383f 100755 --- a/models/db/engine.go +++ b/models/db/engine.go @@ -12,6 +12,7 @@ import ( "io" "reflect" "runtime/trace" + "strconv" "strings" "time" @@ -34,6 +35,7 @@ var ( ) // Engine represents a xorm engine or session. +// (Our Engine interface remains unchanged.) type Engine interface { Table(tableNameOrBean any) *xorm.Session Count(...any) (int64, error) @@ -70,7 +72,7 @@ type Engine interface { Ping() error } -// TableInfo returns table's information via an object +// TableInfo remains the same – it will call x.TableInfo on the underlying engine or group. func TableInfo(v any) (*schemas.Table, error) { return x.TableInfo(v) } @@ -80,7 +82,7 @@ func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) erro return x.DumpTables(tables, w, tp...) } -// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync +// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync. func RegisterModel(bean any, initFunc ...func() error) { tables = append(tables, bean) if len(initFuncs) > 0 && initFunc[0] != nil { @@ -95,34 +97,96 @@ func init() { } } -// newXORMEngine returns a new XORM engine from the configuration -func newXORMEngine() (*xorm.Engine, error) { - connStr, err := setting.DBConnStr() +// newXORMEngineGroup creates an xorm.EngineGroup (with one master and one or more slaves). +// It assumes you have separate master and slave DSNs defined via the settings package. +func newXORMEngineGroup() (Engine, error) { + // Retrieve master DSN from settings. + masterConnStr, err := setting.DBMasterConnStr() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to determine master DSN: %w", err) } - var engine *xorm.Engine - + var masterEngine *xorm.Engine + // For PostgreSQL: if a schema is provided, we use the special “postgresschema” driver. if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 { - // OK whilst we sort out our schema issues - create a schema aware postgres registerPostgresSchemaDriver() - engine, err = xorm.NewEngine("postgresschema", connStr) + masterEngine, err = xorm.NewEngine("postgresschema", masterConnStr) } else { - engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr) + masterEngine, err = xorm.NewEngine(setting.Database.Type.String(), masterConnStr) } - if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create master engine: %w", err) } if setting.Database.Type.IsMySQL() { - engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + masterEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) } - engine.SetSchema(setting.Database.Schema) - return engine, nil + masterEngine.SetSchema(setting.Database.Schema) + + // Get slave DSNs. + slaveConnStrs, err := setting.DBSlaveConnStrs() + if err != nil { + return nil, fmt.Errorf("failed to load slave DSNs: %w", err) + } + + var slaveEngines []*xorm.Engine + // Iterate over all slave DSNs and create engines. + for _, dsn := range slaveConnStrs { + slaveEngine, err := xorm.NewEngine(setting.Database.Type.String(), dsn) + if err != nil { + return nil, fmt.Errorf("failed to create slave engine for dsn %q: %w", dsn, err) + } + if setting.Database.Type.IsMySQL() { + slaveEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + } + slaveEngine.SetSchema(setting.Database.Schema) + slaveEngines = append(slaveEngines, slaveEngine) + } + + // Build load balance policy from user settings. + var policy xorm.GroupPolicy + switch setting.Database.LoadBalancePolicy { + case "WeightRandom": + var weights []int + if setting.Database.LoadBalanceWeights != "" { + for part := range strings.SplitSeq(setting.Database.LoadBalanceWeights, ",") { + w, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + w = 1 // use a default weight if conversion fails + } + weights = append(weights, w) + } + } + // If no valid weights were provided, default each slave to weight 1. + if len(weights) == 0 { + weights = make([]int, len(slaveEngines)) + for i := range weights { + weights[i] = 1 + } + } + policy = xorm.WeightRandomPolicy(weights) + case "RoundRobin": + policy = xorm.RoundRobinPolicy() + default: + policy = xorm.RandomPolicy() + } + // Create the EngineGroup using the selected policy. + group, err := xorm.NewEngineGroup(masterEngine, slaveEngines, policy) + if err != nil { + return nil, fmt.Errorf("failed to create engine group: %w", err) + } + return engineGroupWrapper{group}, nil } -// SyncAllTables sync the schemas of all tables, is required by unit test code +type engineGroupWrapper struct { + *xorm.EngineGroup +} + +func (w engineGroupWrapper) AddHook(hook contexts.Hook) bool { + w.EngineGroup.AddHook(hook) + return true +} + +// SyncAllTables sync the schemas of all tables. func SyncAllTables() error { _, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{ WarnIfDatabaseColumnMissed: true, @@ -130,59 +194,65 @@ func SyncAllTables() error { return err } -// InitEngine initializes the xorm.Engine and sets it as db.DefaultContext +// InitEngine initializes the xorm EngineGroup and sets it as db.DefaultContext. func InitEngine(ctx context.Context) error { - xormEngine, err := newXORMEngine() + xormEngine, err := newXORMEngineGroup() if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } + // Try to cast to the concrete type to access diagnostic methods. + if eng, ok := xormEngine.(engineGroupWrapper); ok { + eng.SetMapper(names.GonicMapper{}) + // WARNING: for serv command, MUST remove the output to os.Stdout, + // so use a log file instead of printing to stdout. + eng.SetLogger(NewXORMLogger(setting.Database.LogSQL)) + eng.ShowSQL(setting.Database.LogSQL) + eng.SetMaxOpenConns(setting.Database.MaxOpenConns) + eng.SetMaxIdleConns(setting.Database.MaxIdleConns) + eng.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) + eng.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) + eng.SetDefaultContext(ctx) - xormEngine.SetMapper(names.GonicMapper{}) - // WARNING: for serv command, MUST remove the output to os.stdout, - // so use log file to instead print to stdout. - xormEngine.SetLogger(NewXORMLogger(setting.Database.LogSQL)) - xormEngine.ShowSQL(setting.Database.LogSQL) - xormEngine.SetMaxOpenConns(setting.Database.MaxOpenConns) - xormEngine.SetMaxIdleConns(setting.Database.MaxIdleConns) - xormEngine.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) - xormEngine.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) - xormEngine.SetDefaultContext(ctx) + if setting.Database.SlowQueryThreshold > 0 { + eng.AddHook(&SlowQueryHook{ + Treshold: setting.Database.SlowQueryThreshold, + Logger: log.GetLogger("xorm"), + }) + } - if setting.Database.SlowQueryThreshold > 0 { - xormEngine.AddHook(&SlowQueryHook{ - Treshold: setting.Database.SlowQueryThreshold, - Logger: log.GetLogger("xorm"), + errorLogger := log.GetLogger("xorm") + if setting.IsInTesting { + errorLogger = log.GetLogger(log.DEFAULT) + } + + eng.AddHook(&ErrorQueryHook{ + Logger: errorLogger, }) + + eng.AddHook(&TracingHook{}) + + SetDefaultEngine(ctx, eng) + } else { + // Fallback: if type assertion fails, set default engine without extended diagnostics. + SetDefaultEngine(ctx, xormEngine) } - - errorLogger := log.GetLogger("xorm") - if setting.IsInTesting { - errorLogger = log.GetLogger(log.DEFAULT) - } - - xormEngine.AddHook(&ErrorQueryHook{ - Logger: errorLogger, - }) - - xormEngine.AddHook(&TracingHook{}) - - SetDefaultEngine(ctx, xormEngine) return nil } -// SetDefaultEngine sets the default engine for db -func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) { - x = eng +// SetDefaultEngine sets the default engine for db. +func SetDefaultEngine(ctx context.Context, eng Engine) { + if engine, ok := eng.(*xorm.Engine); ok { + x = engine + } else if group, ok := eng.(engineGroupWrapper); ok { + x = group.Master() + } DefaultContext = &Context{ Context: ctx, - e: x, + e: eng, } } -// UnsetDefaultEngine closes and unsets the default engine -// We hope the SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now, -// there are many calls to InitEngine -> SetDefaultEngine directly to overwrite the `x` and DefaultContext without close -// Global database engine related functions are all racy and there is no graceful close right now. +// UnsetDefaultEngine closes and unsets the default engine. func UnsetDefaultEngine() { if x != nil { _ = x.Close() @@ -191,12 +261,8 @@ func UnsetDefaultEngine() { DefaultContext = nil } -// InitEngineWithMigration initializes a new xorm.Engine and sets it as the db.DefaultContext -// This function must never call .Sync() if the provided migration function fails. -// When called from the "doctor" command, the migration function is a version check -// that prevents the doctor from fixing anything in the database if the migration level -// is different from the expected value. -func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) { +// InitEngineWithMigration initializes a new xorm EngineGroup, runs migrations, and sets it as db.DefaultContext. +func InitEngineWithMigration(ctx context.Context, migrateFunc func(Engine) error) (err error) { if err = InitEngine(ctx); err != nil { return err } @@ -207,12 +273,7 @@ func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) preprocessDatabaseCollation(x) - // We have to run migrateFunc here in case the user is re-running installation on a previously created DB. - // If we do not then table schemas will be changed and there will be conflicts when the migrations run properly. - // - // Installation should only be being re-run if users want to recover an old database. - // However, we should think carefully about should we support re-install on an installed instance, - // as there may be other problems due to secret reinitialization. + // Run migration function. if err = migrateFunc(x); err != nil { return fmt.Errorf("migrate: %w", err) } @@ -230,14 +291,14 @@ func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) return nil } -// NamesToBean return a list of beans or an error +// NamesToBean returns a list of beans given names. func NamesToBean(names ...string) ([]any, error) { beans := []any{} if len(names) == 0 { beans = append(beans, tables...) return beans, nil } - // Need to map provided names to beans... + // Map provided names to beans. beanMap := make(map[string]any) for _, bean := range tables { beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean @@ -259,7 +320,7 @@ func NamesToBean(names ...string) ([]any, error) { return beans, nil } -// DumpDatabase dumps all data from database according the special database SQL syntax to file system. +// DumpDatabase dumps all data from database using special SQL syntax to the file system. func DumpDatabase(filePath, dbType string) error { var tbs []*schemas.Table for _, t := range tables { @@ -286,7 +347,7 @@ func DumpDatabase(filePath, dbType string) error { return x.DumpTablesToFile(tbs, filePath) } -// MaxBatchInsertSize returns the table's max batch insert size +// MaxBatchInsertSize returns the table's max batch insert size. func MaxBatchInsertSize(bean any) int { t, err := x.TableInfo(bean) if err != nil { @@ -295,18 +356,18 @@ func MaxBatchInsertSize(bean any) int { return 999 / len(t.ColumnsSeq()) } -// IsTableNotEmpty returns true if table has at least one record +// IsTableNotEmpty returns true if the table has at least one record. func IsTableNotEmpty(beanOrTableName any) (bool, error) { return x.Table(beanOrTableName).Exist() } -// DeleteAllRecords will delete all the records of this table +// DeleteAllRecords deletes all records in the given table. func DeleteAllRecords(tableName string) error { _, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName)) return err } -// GetMaxID will return max id of the table +// GetMaxID returns the maximum id in the table. func GetMaxID(beanOrTableName any) (maxID int64, err error) { _, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID) return maxID, err @@ -314,8 +375,8 @@ func GetMaxID(beanOrTableName any) (maxID int64, err error) { func SetLogSQL(ctx context.Context, on bool) { e := GetEngine(ctx) - if x, ok := e.(*xorm.Engine); ok { - x.ShowSQL(on) + if eng, ok := e.(*xorm.Engine); ok { + eng.ShowSQL(on) } else if sess, ok := e.(*xorm.Session); ok { sess.Engine().ShowSQL(on) } diff --git a/modules/setting/database.go b/modules/setting/database.go index 76fae27164..60cbaeb228 100644 --- a/modules/setting/database.go +++ b/modules/setting/database.go @@ -1,6 +1,3 @@ -// Copyright 2019 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - package setting import ( @@ -12,6 +9,8 @@ import ( "path/filepath" "strings" "time" + + "code.gitea.io/gitea/modules/log" ) var ( @@ -27,6 +26,10 @@ var ( Database = struct { Type DatabaseType Host string + HostPrimary string + HostReplica string + LoadBalancePolicy string + LoadBalanceWeights string Name string User string Passwd string @@ -63,6 +66,10 @@ func loadDBSetting(rootCfg ConfigProvider) { Database.Type = DatabaseType(sec.Key("DB_TYPE").String()) Database.Host = sec.Key("HOST").String() + Database.HostPrimary = sec.Key("HOST_PRIMARY").String() + Database.HostReplica = sec.Key("HOST_REPLICA").String() + Database.LoadBalancePolicy = sec.Key("LOAD_BALANCE_POLICY").MustString("xorm.RandomPolicy()") + Database.LoadBalanceWeights = sec.Key("LOAD_BALANCE_WEIGHTS").String() Database.Name = sec.Key("NAME").String() Database.User = sec.Key("USER").String() if len(Database.Passwd) == 0 { @@ -99,8 +106,62 @@ func loadDBSetting(rootCfg ConfigProvider) { } } -// DBConnStr returns database connection string +// DBConnStr returns a database connection string using Database.Host. func DBConnStr() (string, error) { + return dbConnStrWithHost(Database.Host) +} + +// DBMasterConnStr returns the connection string for the master (primary) database. +// If a primary host is defined in the configuration, it is used; +// otherwise, it falls back to Database.Host. +// Returns an error if no master host is provided. +func DBMasterConnStr() (string, error) { + var host string + if Database.HostPrimary != "" { + host = Database.HostPrimary + } else { + host = Database.Host + } + if host == "" { + return "", fmt.Errorf("master host is not defined while slave is defined; cannot proceed") + } + return dbConnStrWithHost(host) +} + +// DBSlaveConnStrs returns one or more connection strings for the replica databases. +// If a replica host is defined (possibly as a comma‐separated list) then those DSNs are returned. +// Otherwise, this function falls back to the master DSN (with a warning log). +func DBSlaveConnStrs() ([]string, error) { + var dsns []string + if Database.HostReplica != "" { + // support multiple replica hosts separated by commas + replicas := strings.SplitSeq(Database.HostReplica, ",") + for r := range replicas { + trimmed := strings.TrimSpace(r) + if trimmed == "" { + continue + } + dsn, err := dbConnStrWithHost(trimmed) + if err != nil { + return nil, err + } + dsns = append(dsns, dsn) + } + } + // Fall back to master if no slave DSN was provided. + if len(dsns) == 0 { + master, err := DBMasterConnStr() + if err != nil { + return nil, err + } + log.Info("DB: No dedicated replica host defined; falling back to primary DSN for replica connections") + dsns = append(dsns, master) + } + return dsns, nil +} + +// dbConnStrWithHost constructs the connection string, given a host value. +func dbConnStrWithHost(host string) (string, error) { var connStr string paramSep := "?" if strings.Contains(Database.Name, paramSep) { @@ -109,23 +170,25 @@ func DBConnStr() (string, error) { switch Database.Type { case "mysql": connType := "tcp" - if len(Database.Host) > 0 && Database.Host[0] == '/' { // looks like a unix socket + // if the host starts with '/' it is assumed to be a unix socket path + if len(host) > 0 && host[0] == '/' { connType = "unix" } tls := Database.SSLMode - if tls == "disable" { // allow (Postgres-inspired) default value to work in MySQL + // allow the "disable" value (borrowed from Postgres defaults) to behave as false + if tls == "disable" { tls = "false" } connStr = fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", - Database.User, Database.Passwd, connType, Database.Host, Database.Name, paramSep, tls) + Database.User, Database.Passwd, connType, host, Database.Name, paramSep, tls) case "postgres": - connStr = getPostgreSQLConnectionString(Database.Host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) + connStr = getPostgreSQLConnectionString(host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) case "sqlite3": if !EnableSQLite3 { return "", errors.New("this Gitea binary was not built with SQLite3 support") } if err := os.MkdirAll(filepath.Dir(Database.Path), os.ModePerm); err != nil { - return "", fmt.Errorf("Failed to create directories: %w", err) + return "", fmt.Errorf("failed to create directories: %w", err) } journalMode := "" if Database.SQLiteJournalMode != "" { @@ -136,7 +199,6 @@ func DBConnStr() (string, error) { default: return "", fmt.Errorf("unknown database type: %s", Database.Type) } - return connStr, nil } @@ -185,6 +247,31 @@ func getPostgreSQLConnectionString(dbHost, dbUser, dbPasswd, dbName, dbsslMode s return connURL.String() } +func getPostgreSQLEngineGroupConnectionStrings(primaryHost, replicaHosts, user, passwd, name, sslmode string) (string, []string) { + // Determine the primary connection string. + primary := primaryHost + if strings.TrimSpace(primary) == "" { + primary = "127.0.0.1:5432" + } + primaryConn := getPostgreSQLConnectionString(primary, user, passwd, name, sslmode) + + // Build the replica connection strings. + replicaConns := []string{} + if strings.TrimSpace(replicaHosts) != "" { + // Split comma-separated replica host values. + hosts := strings.Split(replicaHosts, ",") + for _, h := range hosts { + trimmed := strings.TrimSpace(h) + if trimmed != "" { + replicaConns = append(replicaConns, + getPostgreSQLConnectionString(trimmed, user, passwd, name, sslmode)) + } + } + } + + return primaryConn, replicaConns +} + type DatabaseType string func (t DatabaseType) String() string { diff --git a/modules/setting/database_test.go b/modules/setting/database_test.go index a742d54f8c..1079d7a6f9 100644 --- a/modules/setting/database_test.go +++ b/modules/setting/database_test.go @@ -107,3 +107,79 @@ func Test_getPostgreSQLConnectionString(t *testing.T) { assert.Equal(t, test.Output, connStr) } } + +func Test_getPostgreSQLEngineGroupConnectionStrings(t *testing.T) { + tests := []struct { + primaryHost string // primary host setting (e.g. "localhost" or "[::1]:1234") + replicaHosts string // comma-separated replica hosts (e.g. "replica1,replica2:2345") + user string + passwd string + name string + sslmode string + outputPrimary string + outputReplicas []string + }{ + { + // No primary override (empty => default) and no replicas. + primaryHost: "", + replicaHosts: "", + user: "", + passwd: "", + name: "", + sslmode: "", + outputPrimary: "postgres://:@127.0.0.1:5432?sslmode=", + outputReplicas: []string{}, + }, + { + // Primary set and one replica. + primaryHost: "localhost", + replicaHosts: "replicahost", + user: "user", + passwd: "pass", + name: "gitea", + sslmode: "disable", + outputPrimary: "postgres://user:pass@localhost:5432/gitea?sslmode=disable", + outputReplicas: []string{"postgres://user:pass@replicahost:5432/gitea?sslmode=disable"}, + }, + { + // Primary with explicit port; multiple replicas (one without and one with an explicit port). + primaryHost: "localhost:5433", + replicaHosts: "replica1,replica2:5434", + user: "test", + passwd: "secret", + name: "db", + sslmode: "require", + outputPrimary: "postgres://test:secret@localhost:5433/db?sslmode=require", + outputReplicas: []string{ + "postgres://test:secret@replica1:5432/db?sslmode=require", + "postgres://test:secret@replica2:5434/db?sslmode=require", + }, + }, + { + // IPv6 addresses for primary and replica. + primaryHost: "[::1]:1234", + replicaHosts: "[::2]:2345", + user: "ipv6", + passwd: "ipv6pass", + name: "ipv6db", + sslmode: "disable", + outputPrimary: "postgres://ipv6:ipv6pass@::1:1234/ipv6db?sslmode=disable", + outputReplicas: []string{ + "postgres://ipv6:ipv6pass@::2:2345/ipv6db?sslmode=disable", + }, + }, + } + + for _, test := range tests { + primary, replicas := getPostgreSQLEngineGroupConnectionStrings( + test.primaryHost, + test.replicaHosts, + test.user, + test.passwd, + test.name, + test.sslmode, + ) + assert.Equal(t, test.outputPrimary, primary) + assert.Equal(t, test.outputReplicas, replicas) + } +} diff --git a/routers/common/db.go b/routers/common/db.go index ac24303989..d6e95f149a 100644 --- a/routers/common/db.go +++ b/routers/common/db.go @@ -28,7 +28,7 @@ func InitDBEngine(ctx context.Context) (err error) { default: } log.Info("ORM engine initialization attempt #%d/%d...", i+1, setting.Database.DBConnectRetries) - if err = db.InitEngineWithMigration(ctx, migrateWithSetting); err == nil { + if err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error { return migrateWithSetting(eng.(*xorm.Engine)) }); err == nil { break } else if i == setting.Database.DBConnectRetries-1 { return err diff --git a/routers/install/install.go b/routers/install/install.go index 86e342f1f9..aaf251f0e7 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -36,6 +36,7 @@ import ( "code.gitea.io/gitea/services/forms" "code.forgejo.org/go-chi/session" + "xorm.io/xorm" ) const ( @@ -361,7 +362,17 @@ func SubmitInstall(ctx *context.Context) { } // Init the engine with migration - if err = db.InitEngineWithMigration(ctx, migrations.Migrate); err != nil { + // Wrap migrations.Migrate into a function of type func(db.Engine) error to fix diagnostics. + wrapperMigrate := func(e db.Engine) error { + var xe *xorm.Engine + if getter, ok := e.(interface{ Master() *xorm.Engine }); ok { + xe = getter.Master() + } else { + xe = e.(*xorm.Engine) + } + return migrations.Migrate(xe) + } + if err = db.InitEngineWithMigration(ctx, wrapperMigrate); err != nil { db.UnsetDefaultEngine() ctx.Data["Err_DbSetting"] = true ctx.RenderWithErr(ctx.Tr("install.invalid_db_setting", err), tplInstall, &form) @@ -587,7 +598,7 @@ func SubmitInstall(ctx *context.Context) { go func() { // Sleep for a while to make sure the user's browser has loaded the post-install page and its assets (images, css, js) - // What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Gitea in the future .... + // What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Forgejo in the future .... time.Sleep(3 * time.Second) // Now get the http.Server from this request and shut it down diff --git a/services/doctor/dbconsistency.go b/services/doctor/dbconsistency.go index 9e2fcb645f..9968a6384a 100644 --- a/services/doctor/dbconsistency.go +++ b/services/doctor/dbconsistency.go @@ -16,6 +16,8 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + + "xorm.io/xorm" ) type consistencyCheck struct { @@ -78,7 +80,16 @@ func genericOrphanCheck(name, subject, refobject, joincond string) consistencyCh func checkDBConsistency(ctx context.Context, logger log.Logger, autofix bool) error { // make sure DB version is up-to-date - if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { + ensureUpToDateWrapper := func(e db.Engine) error { + var engine *xorm.Engine + if getter, ok := e.(interface{ Master() *xorm.Engine }); ok { + engine = getter.Master() + } else { + engine = e.(*xorm.Engine) + } + return migrations.EnsureUpToDate(engine) + } + if err := db.InitEngineWithMigration(ctx, ensureUpToDateWrapper); err != nil { logger.Critical("Model version on the database does not match the current Gitea version. Model consistency will not be checked until the database is upgraded") return err } diff --git a/services/doctor/dbversion.go b/services/doctor/dbversion.go index 2a102b2194..bc53d4b4a7 100644 --- a/services/doctor/dbversion.go +++ b/services/doctor/dbversion.go @@ -9,11 +9,15 @@ import ( "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/migrations" "code.gitea.io/gitea/modules/log" + + "xorm.io/xorm" ) func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error { logger.Info("Expected database version: %d", migrations.ExpectedDBVersion()) - if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { + if err := db.InitEngineWithMigration(ctx, func(eng db.Engine) error { + return migrations.EnsureUpToDate(eng.(*xorm.Engine)) + }); err != nil { if !autofix { logger.Critical("Error: %v during ensure up to date", err) return err @@ -21,7 +25,9 @@ func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error logger.Warn("Got Error: %v during ensure up to date", err) logger.Warn("Attempting to migrate to the latest DB version to fix this.") - err = db.InitEngineWithMigration(ctx, migrations.Migrate) + err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error { + return migrations.Migrate(eng.(*xorm.Engine)) + }) if err != nil { logger.Critical("Error: %v during migration", err) } diff --git a/tests/integration/migration-test/migration_test.go b/tests/integration/migration-test/migration_test.go index 729d8e0dff..2e0b3f2ccb 100644 --- a/tests/integration/migration-test/migration_test.go +++ b/tests/integration/migration-test/migration_test.go @@ -278,23 +278,44 @@ func doMigrationTest(t *testing.T, version string) { setting.InitSQLLoggersForCli(log.INFO) - err := db.InitEngineWithMigration(t.Context(), wrappedMigrate) + err := db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + var engine *xorm.Engine + if eg, ok := e.(interface{ Master() *xorm.Engine }); ok { + engine = eg.Master() + } else { + engine = e.(*xorm.Engine) + } + currentEngine = engine + return wrappedMigrate(engine) + }) require.NoError(t, err) currentEngine.Close() beans, _ := db.NamesToBean() - err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { - currentEngine = x - return migrate_base.RecreateTables(beans...)(x) + err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + var engine *xorm.Engine + if eg, ok := e.(interface{ Master() *xorm.Engine }); ok { + engine = eg.Master() + } else { + engine = e.(*xorm.Engine) + } + currentEngine = engine + return migrate_base.RecreateTables(beans...)(engine) }) require.NoError(t, err) currentEngine.Close() // We do this a second time to ensure that there is not a problem with retained indices - err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { - currentEngine = x - return migrate_base.RecreateTables(beans...)(x) + err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + var engine *xorm.Engine + if eg, ok := e.(interface{ Master() *xorm.Engine }); ok { + engine = eg.Master() + } else { + engine = e.(*xorm.Engine) + } + currentEngine = engine + return migrate_base.RecreateTables(beans...)(engine) }) require.NoError(t, err)