diff --git a/pkg/database/migrate/migration b/pkg/database/migrate/migration new file mode 100755 index 000000000..3e3a2ff63 --- /dev/null +++ b/pkg/database/migrate/migration @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +REV=$(date -u +%Y%m%d%H%M%S) +filename=$1 + +TAB="$(printf '\t')" + +titleize() { + echo "$1" | sed -r -e "s/-|_/ /g" -e 's/\b(.)/\U\1/g' -e 's/ //g' +} + +cat > ${REV}_$filename.go << EOF +package migrate + +type rev${REV} struct{} + +var $(titleize $filename) = &rev${REV}{} + +func (r *rev$REV) Revision() int64 { +${TAB}return $REV +} + +func (r *rev$REV) Up(op Operation) error { +${TAB}// Migration steps here +} + +func (r *rev$REV) Down(op Operation) error { +${TAB}// Revert migration steps here +} +EOF diff --git a/pkg/database/migrate/sqlite.go b/pkg/database/migrate/sqlite.go index 140988a47..2cec5a026 100644 --- a/pkg/database/migrate/sqlite.go +++ b/pkg/database/migrate/sqlite.go @@ -16,15 +16,15 @@ func SQLite(tx *sql.Tx) Operation { } func (s *SQLiteDriver) Exec(query string, args ...interface{}) (sql.Result, error) { - return s.Tx.Exec(query, args) + return s.Tx.Exec(query, args...) } func (s *SQLiteDriver) Query(query string, args ...interface{}) (*sql.Rows, error) { - return s.Tx.Query(query, args) + return s.Tx.Query(query, args...) } func (s *SQLiteDriver) QueryRow(query string, args ...interface{}) *sql.Row { - return s.Tx.QueryRow(query, args) + return s.Tx.QueryRow(query, args...) } func (s *SQLiteDriver) CreateTable(tableName string, args []string) (sql.Result, error) { @@ -44,17 +44,19 @@ func (s *SQLiteDriver) AddColumn(tableName, columnSpec string) (sql.Result, erro } func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sql.Result, error) { + var err error + var result sql.Result if len(columnsToDrop) == 0 { return nil, fmt.Errorf("No columns to drop.") } - sql, err := s.getDDLFromTable(tableName) + tableSQL, err := s.getDDLFromTable(tableName) if err != nil { return nil, err } - columns, err := fetchColumns(sql) + columns, err := fetchColumns(tableSQL) if err != nil { return nil, err } @@ -79,6 +81,38 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq return nil, fmt.Errorf("No columns match, drops nothing.") } + // fetch indices for this table + oldSQLIndices, err := s.getDDLFromIndex(tableName) + if err != nil { + return nil, err + } + + var oldIdxColumns [][]string + for _, idx := range oldSQLIndices { + idxCols, err := fetchColumns(idx) + if err != nil { + return nil, err + } + oldIdxColumns = append(oldIdxColumns, idxCols) + } + + var indices []string + for k, idx := range oldSQLIndices { + listed := false + OIdxLoop: + for _, oidx := range oldIdxColumns[k] { + for _, cols := range columnsToDrop { + if oidx == cols { + listed = true + break OIdxLoop + } + } + } + if !listed { + indices = append(indices, idx) + } + } + // Rename old table, here's our proxy proxyName := fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16)) if result, err := s.RenameTable(tableName, proxyName); err != nil { @@ -86,27 +120,40 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq } // Recreate table with dropped columns omitted - if result, err := s.CreateTable(tableName, preparedColumns); err != nil { + if result, err = s.CreateTable(tableName, preparedColumns); err != nil { return result, err } // Move data from old table - if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName, + if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName, strings.Join(selectName(preparedColumns), ", "), proxyName)); err != nil { return result, err } // Clean up proxy table - return s.DropTable(proxyName) + if result, err = s.DropTable(proxyName); err != nil { + return result, err + } + + // Recreate Indices + for _, idx := range indices { + if result, err = s.Tx.Exec(idx); err != nil { + return result, err + } + } + return result, err } func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) { - sql, err := s.getDDLFromTable(tableName) + var err error + var result sql.Result + + tableSQL, err := s.getDDLFromTable(tableName) if err != nil { return nil, err } - columns, err := fetchColumns(sql) + columns, err := fetchColumns(tableSQL) if err != nil { return nil, err } @@ -132,6 +179,40 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] } } + // fetch indices for this table + oldSQLIndices, err := s.getDDLFromIndex(tableName) + if err != nil { + return nil, err + } + + var idxColumns [][]string + for _, idx := range oldSQLIndices { + idxCols, err := fetchColumns(idx) + if err != nil { + return nil, err + } + idxColumns = append(idxColumns, idxCols) + } + + var indices []string + for k, idx := range oldSQLIndices { + added := false + IdcLoop: + for _, oldIdx := range idxColumns[k] { + for Old, New := range columnChanges { + if oldIdx == Old { + indx := strings.Replace(idx, Old, New, 2) + indices = append(indices, indx) + added = true + break IdcLoop + } + } + } + if !added { + indices = append(indices, idx) + } + } + // Rename current table proxyName := fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16)) if result, err := s.RenameTable(tableName, proxyName); err != nil { @@ -139,18 +220,27 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] } // Create new table with the new columns - if result, err := s.CreateTable(tableName, newColumns); err != nil { + if result, err = s.CreateTable(tableName, newColumns); err != nil { return result, err } // Migrate data - if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName, + if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName, strings.Join(oldColumnsName, ", "), proxyName)); err != nil { return result, err } // Clean up proxy table - return s.DropTable(proxyName) + if result, err = s.DropTable(proxyName); err != nil { + return result, err + } + + for _, idx := range indices { + if result, err = s.Tx.Exec(idx); err != nil { + return result, err + } + } + return result, err } func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { @@ -162,3 +252,32 @@ func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { } return sql, nil } + +func (s *SQLiteDriver) getDDLFromIndex(tableName string) ([]string, error) { + var sqls []string + + query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?;` + rows, err := s.Tx.Query(query, tableName) + if err != nil { + return sqls, err + } + + for rows.Next() { + var sql string + if err := rows.Scan(&sql); err != nil { + // This error came from autoindex, since its sql value is null, + // we want to continue. + if strings.Contains(err.Error(), "Scan pair: -> *string") { + continue + } + return sqls, err + } + sqls = append(sqls, sql) + } + + if err := rows.Err(); err != nil { + return sqls, err + } + + return sqls, nil +} diff --git a/pkg/database/migrate/sqlite_test.go b/pkg/database/migrate/sqlite_test.go index af81588f7..c26d5b9bf 100644 --- a/pkg/database/migrate/sqlite_test.go +++ b/pkg/database/migrate/sqlite_test.go @@ -3,6 +3,7 @@ package migrate import ( "database/sql" "os" + "strings" "testing" "github.com/russross/meddler" @@ -117,7 +118,117 @@ func (r *revision4) Revision() int64 { return 4 } -// ---------- +// ---------- end of revision 4 + +// ---------- revision 5 + +type revision5 struct{} + +func (r *revision5) Up(op Operation) error { + _, err := op.Exec(`CREATE INDEX samples_url_name_ix ON samples (url, name)`) + return err +} + +func (r *revision5) Down(op Operation) error { + _, err := op.Exec(`DROP INDEX samples_url_name_ix`) + return err +} + +func (r *revision5) Revision() int64 { + return 5 +} + +// ---------- end of revision 5 + +// ---------- revision 6 +type revision6 struct{} + +func (r *revision6) Up(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "url": "host", + }) + return err +} + +func (r *revision6) Down(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "host": "url", + }) + return err +} + +func (r *revision6) Revision() int64 { + return 6 +} + +// ---------- end of revision 6 + +// ---------- revision 7 +type revision7 struct{} + +func (r *revision7) Up(op Operation) error { + _, err := op.DropColumns("samples", []string{"host", "num"}) + return err +} + +func (r *revision7) Down(op Operation) error { + if _, err := op.AddColumn("samples", "host VARCHAR(255)"); err != nil { + return err + } + _, err := op.AddColumn("samples", "num INSTEGER") + return err +} + +func (r *revision7) Revision() int64 { + return 7 +} + +// ---------- end of revision 7 + +// ---------- revision 8 +type revision8 struct{} + +func (r *revision8) Up(op Operation) error { + if _, err := op.AddColumn("samples", "repo_id INTEGER"); err != nil { + return err + } + _, err := op.AddColumn("samples", "repo VARCHAR(255)") + return err +} + +func (r *revision8) Down(op Operation) error { + _, err := op.DropColumns("samples", []string{"repo", "repo_id"}) + return err +} + +func (r *revision8) Revision() int64 { + return 8 +} + +// ---------- end of revision 8 + +// ---------- revision 9 +type revision9 struct{} + +func (r *revision9) Up(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "repo": "repository", + }) + return err +} + +func (r *revision9) Down(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "repository": "repo", + }) + return err +} + +func (r *revision9) Revision() int64 { + return 9 +} + +// ---------- end of revision 9 var db *sql.DB @@ -305,6 +416,90 @@ func TestMigrateExistingTable(t *testing.T) { } } +type sqliteMaster struct { + Sql interface{} `meddler:"sql"` +} + +func TestIndexOperations(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + mgr := New(db) + + // Migrate, create index + if err := mgr.Add(&revision1{}, &revision3{}, &revision5{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var esquel []*sqliteMaster + // Query sqlite_master, check if index is exists. + query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name='samples'` + if err := meddler.QueryAll(db, &esquel, query); err != nil { + t.Errorf("Can not find index: %q", err) + } + + indexStatement := `CREATE INDEX samples_url_name_ix ON samples (url, name)` + if string(esquel[1].Sql.([]byte)) != indexStatement { + t.Errorf("Can not find index") + } + + // Migrate, rename indexed columns + if err := mgr.Add(&revision6{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var esquel1 []*sqliteMaster + if err := meddler.QueryAll(db, &esquel1, query); err != nil { + t.Errorf("Can not find index: %q", err) + } + + indexStatement = `CREATE INDEX samples_host_name_ix ON samples (host, name)` + if string(esquel1[1].Sql.([]byte)) != indexStatement { + t.Errorf("Can not find index, got: %s", esquel[0]) + } + + if err := mgr.Add(&revision7{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var esquel2 []*sqliteMaster + if err := meddler.QueryAll(db, &esquel2, query); err != nil { + t.Errorf("Can not find index: %q", err) + } + + if len(esquel2) != 1 { + t.Errorf("Expect row length equal to %d, got %d", 1, len(esquel2)) + } +} + +func TestColumnRedundancy(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + migr := New(db) + if err := migr.Add(&revision1{}, &revision8{}, &revision9{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var tableSql string + query := `SELECT sql FROM sqlite_master where type='table' and name='samples'` + if err := db.QueryRow(query).Scan(&tableSql); err != nil { + t.Errorf("Can not query sqlite_master: %q", err) + } + + if !strings.Contains(tableSql, "repository ") { + t.Errorf("Expect column with name repository") + } +} + func setUp() error { var err error db, err = sql.Open("sqlite3", "migration_tests.sqlite") diff --git a/pkg/database/migrate/util.go b/pkg/database/migrate/util.go index a0f6bfb59..d2001a071 100644 --- a/pkg/database/migrate/util.go +++ b/pkg/database/migrate/util.go @@ -6,12 +6,12 @@ import ( ) func fetchColumns(sql string) ([]string, error) { - if !strings.HasPrefix(sql, "CREATE TABLE ") { + if !strings.HasPrefix(sql, "CREATE ") { return []string{}, fmt.Errorf("Sql input is not a DDL statement.") } parenIdx := strings.Index(sql, "(") - return strings.Split(sql[parenIdx+1:len(sql)-1], ","), nil + return strings.Split(sql[parenIdx+1:strings.LastIndex(sql, ")")], ","), nil } func selectName(columns []string) []string {