woodpecker/server/store/datastore/migration/common.go
2022-01-05 21:50:23 +01:00

243 lines
7.7 KiB
Go

// Copyright 2021 Woodpecker Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migration
import (
"fmt"
"regexp"
"strings"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
func renameTable(sess *xorm.Session, old, new string) error {
dialect := sess.Engine().Dialect().URI().DBType
switch dialect {
case schemas.MYSQL:
_, err := sess.Exec(fmt.Sprintf("RENAME TABLE `%s` TO `%s`;", old, new))
return err
case schemas.POSTGRES, schemas.SQLITE:
_, err := sess.Exec(fmt.Sprintf("ALTER TABLE `%s` RENAME TO `%s`;", old, new))
return err
default:
return fmt.Errorf("dialect '%s' not supported", dialect)
}
}
// WARNING: YOU MUST COMMIT THE SESSION AT THE END
func dropTableColumns(sess *xorm.Session, tableName string, columnNames ...string) (err error) {
// Copyright 2017 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
if tableName == "" || len(columnNames) == 0 {
return nil
}
// TODO: This will not work if there are foreign keys
dialect := sess.Engine().Dialect().URI().DBType
switch dialect {
case schemas.SQLITE:
// First drop the indexes on the columns
res, errIndex := sess.Query(fmt.Sprintf("PRAGMA index_list(`%s`)", tableName))
if errIndex != nil {
return errIndex
}
for _, row := range res {
indexName := row["name"]
indexRes, err := sess.Query(fmt.Sprintf("PRAGMA index_info(`%s`)", indexName))
if err != nil {
return err
}
if len(indexRes) != 1 {
continue
}
indexColumn := string(indexRes[0]["name"])
for _, name := range columnNames {
if name == indexColumn {
_, err := sess.Exec(fmt.Sprintf("DROP INDEX `%s`", indexName))
if err != nil {
return err
}
}
}
}
// Here we need to get the columns from the original table
sql := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE tbl_name='%s' and type='table'", tableName)
res, err := sess.Query(sql)
if err != nil {
return err
}
tableSQL := normalizeSQLiteTableSchema(string(res[0]["sql"]))
// Separate out the column definitions
tableSQL = tableSQL[strings.Index(tableSQL, "("):]
// Remove the required columnNames
tableSQL = removeColumnFromSQLITETableSchema(tableSQL, columnNames...)
// Ensure the query is ended properly
tableSQL = strings.TrimSpace(tableSQL)
if tableSQL[len(tableSQL)-1] != ')' {
if tableSQL[len(tableSQL)-1] == ',' {
tableSQL = tableSQL[:len(tableSQL)-1]
}
tableSQL += ")"
}
// Find all the columns in the table
var columns []string
for _, rawColumn := range strings.Split(strings.ReplaceAll(tableSQL[1:len(tableSQL)-1], ", ", ",\n"), "\n") {
if strings.ContainsAny(rawColumn, "()") {
continue
}
rawColumn = strings.TrimSpace(rawColumn)
columns = append(columns,
strings.ReplaceAll(rawColumn[0:strings.Index(rawColumn, " ")], "`", ""),
)
}
tableSQL = fmt.Sprintf("CREATE TABLE `new_%s_new` ", tableName) + tableSQL
if _, err := sess.Exec(tableSQL); err != nil {
return err
}
// Now restore the data
columnsSeparated := strings.Join(columns, ",")
insertSQL := fmt.Sprintf("INSERT INTO `new_%s_new` (%s) SELECT %s FROM %s", tableName, columnsSeparated, columnsSeparated, tableName)
if _, err := sess.Exec(insertSQL); err != nil {
return err
}
// Now drop the old table
if _, err := sess.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName)); err != nil {
return err
}
// Rename the table
if _, err := sess.Exec(fmt.Sprintf("ALTER TABLE `new_%s_new` RENAME TO `%s`", tableName, tableName)); err != nil {
return err
}
case schemas.POSTGRES:
cols := ""
for _, col := range columnNames {
if cols != "" {
cols += ", "
}
cols += "DROP COLUMN `" + col + "` CASCADE"
}
if _, err := sess.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, cols)); err != nil {
return fmt.Errorf("drop table `%s` columns %v: %v", tableName, columnNames, err)
}
case schemas.MYSQL:
// Drop indexes on columns first
sql := fmt.Sprintf("SHOW INDEX FROM %s WHERE column_name IN ('%s')", tableName, strings.Join(columnNames, "','"))
res, err := sess.Query(sql)
if err != nil {
return err
}
for _, index := range res {
indexName := index["column_name"]
if len(indexName) > 0 {
_, err := sess.Exec(fmt.Sprintf("DROP INDEX `%s` ON `%s`", indexName, tableName))
if err != nil {
return err
}
}
}
// Now drop the columns
cols := ""
for _, col := range columnNames {
if cols != "" {
cols += ", "
}
cols += "DROP COLUMN `" + col + "`"
}
if _, err := sess.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, cols)); err != nil {
return fmt.Errorf("drop table `%s` columns %v: %v", tableName, columnNames, err)
}
case schemas.MSSQL:
cols := ""
for _, col := range columnNames {
if cols != "" {
cols += ", "
}
cols += "`" + strings.ToLower(col) + "`"
}
sql := fmt.Sprintf("SELECT Name FROM sys.default_constraints WHERE parent_object_id = OBJECT_ID('%[1]s') AND parent_column_id IN (SELECT column_id FROM sys.columns WHERE LOWER(name) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))",
tableName, strings.ReplaceAll(cols, "`", "'"))
constraints := make([]string, 0)
if err := sess.SQL(sql).Find(&constraints); err != nil {
return fmt.Errorf("find constraints: %v", err)
}
for _, constraint := range constraints {
if _, err := sess.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP CONSTRAINT `%s`", tableName, constraint)); err != nil {
return fmt.Errorf("drop table `%s` default constraint `%s`: %v", tableName, constraint, err)
}
}
sql = fmt.Sprintf("SELECT DISTINCT Name FROM sys.indexes INNER JOIN sys.index_columns ON indexes.index_id = index_columns.index_id AND indexes.object_id = index_columns.object_id WHERE indexes.object_id = OBJECT_ID('%[1]s') AND index_columns.column_id IN (SELECT column_id FROM sys.columns WHERE LOWER(name) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))",
tableName, strings.ReplaceAll(cols, "`", "'"))
constraints = make([]string, 0)
if err := sess.SQL(sql).Find(&constraints); err != nil {
return fmt.Errorf("find constraints: %v", err)
}
for _, constraint := range constraints {
if _, err := sess.Exec(fmt.Sprintf("DROP INDEX `%[2]s` ON `%[1]s`", tableName, constraint)); err != nil {
return fmt.Errorf("drop index `%[2]s` on `%[1]s`: %v", tableName, constraint, err)
}
}
if _, err := sess.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN %s", tableName, cols)); err != nil {
return fmt.Errorf("drop table `%s` columns %v: %v", tableName, columnNames, err)
}
default:
return fmt.Errorf("dialect '%s' not supported", dialect)
}
return nil
}
var (
whitespaces = regexp.MustCompile(`\s+`)
columnSeparator = regexp.MustCompile(`\s?,\s?`)
)
func removeColumnFromSQLITETableSchema(schema string, names ...string) string {
if len(names) == 0 {
return schema
}
for i := range names {
if len(names[i]) == 0 {
continue
}
schema = regexp.MustCompile(`\s(`+
regexp.QuoteMeta("`"+names[i]+"`")+
"|"+
regexp.QuoteMeta(names[i])+
")[^`,)]*?[,)]").ReplaceAllString(schema, "")
}
return schema
}
func normalizeSQLiteTableSchema(schema string) string {
return columnSeparator.ReplaceAllString(
whitespaces.ReplaceAllString(
strings.ReplaceAll(schema, "\n", " "),
" "),
", ")
}