diff --git a/pkg/database/migrate/sqlite.go b/pkg/database/migrate/sqlite.go new file mode 100644 index 000000000..a2074b08e --- /dev/null +++ b/pkg/database/migrate/sqlite.go @@ -0,0 +1,135 @@ +package migrate + +import ( + "database/sql" + "fmt" + "strings" + + _ "github.com/mattn/go-sqlite3" + "github.com/dchest/uniuri" +) + +type SQLiteDriver MigrationDriver + +func SQLite(tx *sql.Tx) Operation { + return &SQLiteDriver{Tx: tx} +} + +func (s *SQLiteDriver) CreateTable(tableName string, args []string) (sql.Result, error) { + return s.Tx.Exec(fmt.Sprintf("CREATE TABLE %s (%s);", tableName, strings.Join(args, ", "))) +} + +func (s *SQLiteDriver) RenameTable(tableName, newName string) (sql.Result, error) { + return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tableName, newName)) +} + +func (s *SQLiteDriver) DropTable(tableName string) (sql.Result, error) { + return s.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)) +} + +func (s *SQLiteDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) { + return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", tableName, columnSpec)) +} + +func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sql.Result, error) { + + if len(columnsToDrop) == 0 { + return nil, fmt.Errorf("No columns to drop.") + } + + sql, err := s.getDDLFromTable(tableName) + if err != nil { + return nil, err + } + + columns, err := fetchColumns(sql) + if err != nil { + return nil, err + } + + columnNames := selectName(columns) + preparedColumns := make([]string, len(columnNames)-len(columnsToDrop)) + for k, column := range columnNames { + listed := false + for _, dropped := range columnsToDrop { + if column == dropped { + listed = true + break + } + } + if !listed { + preparedColumns = append(preparedColumns, columns[k]) + } + } + + if len(preparedColumns) == 0 { + return nil, fmt.Errorf("No columns match, drops nothing.") + } + + // Rename old table, here's our proxy + proxyName := fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16)) + if result, err := s.RenameTable(tableName, proxyName); err != nil { + return result, err + } + + // Recreate table with dropped columns omitted + if result, err := s.CreateTable(tableName, preparedColumns); err != nil { + return result, err + } + + // Move data from old table + if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName, + strings.Join(selectName(preparedColumns), ", "), proxyName)); err != nil { + return result, err + } + + // Clean up proxy table + return s.Tx.Exec(fmt.Sprintf("DROP TABLE %s;", proxyName)) +} + +func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) { + sql, err := s.getDDLFromTable(tableName) + if err != nil { + return nil, err + } + + columns, err := fetchColumns(sql) + if err != nil { + return nil, err + } + + oldColumns := make([]string, len(columnChanges)) + newColumns := make([]string, len(columnChanges)) + for k, column := range selectName(columns) { + for Old, New := range columnChanges { + if column == Old { + columnToAdd := strings.Replace(columns[k], Old, New, 1) + + if results, err := s.AddColumn(tableName, columnToAdd); err != nil { + return results, err + } + + oldColumns = append(oldColumns, Old) + newColumns = append(newColumns, New) + break + } + } + } + + statement := fmt.Sprintf("UPDATE %s SET %s;", tableName, setForUpdate(oldColumns, newColumns)) + if results, err := s.Tx.Exec(statement); err != nil { + return results, err + } + + return s.DropColumns(tableName, oldColumns) +} + +func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { + var sql string + query := `SELECT sql FROM sqlite_master WHERE type='table' and name='?';` + err := s.Tx.QueryRow(query, tableName).Scan(&sql) + if err != nil { + return "", err + } + return sql, nil +} diff --git a/pkg/database/migrate/sqlite_test.go b/pkg/database/migrate/sqlite_test.go new file mode 100644 index 000000000..ab7d8492b --- /dev/null +++ b/pkg/database/migrate/sqlite_test.go @@ -0,0 +1,164 @@ +package migrate + +import ( + "database/sql" + "os" + "testing" + + "github.com/russross/meddler" +) + +type Sample struct { + ID int64 `meddler:"id,pk"` + Imel string `meddler:"imel"` + Name string `meddler:"name"` +} + +type RenameSample struct { + ID int64 `meddler:"id,pk"` + Email string `meddler:"email"` + Name string `meddler:"name"` +} + +type AddColumnSample struct { + ID int64 `meddler:"id,pk"` + Imel string `meddler:"imel"` + Name string `meddler:"name"` + Num int64 `meddler:"num"` +} + +type RemoveColumnSample struct { + ID int64 `meddler:"id,pk"` + Name string `meddler:"name"` +} + +// ---------- revision 1 + +type revision1 struct{} + +func (r *revision1) Up(op Operation) error { + _, err := op.CreateTable("samples", []string{ + "id INTEGER PRIMARY KEY AUTOINCREMENT", + "imel VARCHAR(255) UNIQUE", + "name VARCHAR(255)", + }) + return err +} + +func (r *revision1) Down(op Operation) error { + _, err := op.DropTable("samples") + return err +} + +func (r *revision1) Revision() int64 { + return 1 +} + +// ---------- end of revision 1 + +// ---------- revision 2 + +type revision2 struct{} + +func (r *revision2) Up(op Operation) error { + _, err := op.RenameTable("samples", "examples") + return err +} + +func (r *revision2) Down(op Operation) error { + _, err := op.RenameTable("examples", "samples") + return err +} + +func (r *revision2) Revision() int64 { + return 2 +} + +// ---------- end of revision 2 + +var db *sql.DB + +var testSchema = ` +CREATE TABLE samples ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + imel VARCHAR(255) UNIQUE, + name VARCHAR(255), +); +` + +var dataDump = []string{ + `INSERT INTO samples (imel, name) VALUES ('test@example.com', 'Test Tester');`, + `INSERT INTO samples (imel, name) VALUES ('foo@bar.com', 'Foo Bar');`, + `INSERT INTO samples (imel, name) VALUES ('crash@bandicoot.io', 'Crash Bandicoot');`, +} + +func TestMigrateCreateTable(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + mgr := New(db) + if err := mgr.Add(&revision1{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + sample := Sample{ + ID: 1, + Imel: "test@example.com", + Name: "Test Tester", + } + if err := meddler.Save(db, "samples", &sample); err != nil { + t.Errorf("Can not save data: %q", err) + } +} + +func TestMigrateRenameTable(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + mgr := New(db) + if err := mgr.Add(&revision1{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + loadFixture(t) + + if err := mgr.Add(&revision2{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + sample := Sample{} + if err := meddler.QueryRow(db, &sample, `SELECT * FROM examples WHERE id = ?`, 2); err != nil { + t.Errorf("Can not fetch data: %q", err) + } + + if sample.Imel != "foo@bar.com" { + t.Errorf("Column doesn't match\n\texpect:\t%s\n\tget:\t%s", "foo@bar.com", sample.Imel) + } +} + +func setUp() error { + var err error + db, err = sql.Open("sqlite3", "migration_tests.sqlite") + return err +} + +func tearDown() { + db.Close() + os.Remove("migration_tests.sqlite") +} + +func loadFixture(t *testing.T) { + for _, sql := range dataDump { + if _, err := db.Exec(sql); err != nil { + t.Errorf("Can not insert into database: %q", err) + } + } +} diff --git a/pkg/database/migrate/util.go b/pkg/database/migrate/util.go new file mode 100644 index 000000000..1dfec95d7 --- /dev/null +++ b/pkg/database/migrate/util.go @@ -0,0 +1,32 @@ +package migrate + +import ( + "fmt" + "strings" +) + +func fetchColumns(sql string) ([]string, error) { + if !strings.HasPrefix(sql, "CREATE TABLE ") { + return []string{}, fmt.Errorf("Sql input is not a DDL statement.") + } + + parenIdx := strings.Index(sql, "(") + return strings.Split(sql[parenIdx+1:len(sql)-1], ","), nil +} + +func selectName(columns []string) []string { + results := make([]string, len(columns)) + for _, column := range columns { + col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2) + results = append(results, col[0]) + } + return results +} + +func setForUpdate(left []string, right []string) string { + results := make([]string, len(left)) + for k, str := range left { + results = append(results, fmt.Sprintf("%s = %s", str, right[k])) + } + return strings.Join(results, ", ") +}