diff --git a/pkg/database/migrate/sqlite.go b/pkg/database/migrate/sqlite.go index 2ca211803..20330ac36 100644 --- a/pkg/database/migrate/sqlite.go +++ b/pkg/database/migrate/sqlite.go @@ -44,17 +44,19 @@ func (s *SQLiteDriver) AddColumn(tableName, columnSpec string) (sql.Result, erro } func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sql.Result, error) { + var err error + var result sql.Result if len(columnsToDrop) == 0 { return nil, fmt.Errorf("No columns to drop.") } - sql, err := s.getDDLFromTable(tableName) + tableSQL, err := s.getDDLFromTable(tableName) if err != nil { return nil, err } - columns, err := fetchColumns(sql) + columns, err := fetchColumns(tableSQL) if err != nil { return nil, err } @@ -80,8 +82,12 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq } // fetch indices for this table + oldSQLIndices, err := s.getDDLFromIndex(tableName) + if err != nil { + return nil, err + } + var indices []string - oldSQLIndices := s.getDdlFromIndex(tableName) for _, idx := range oldSQLIndices { listed := false for _, cols := range columnsToDrop { @@ -102,24 +108,22 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq } // Recreate table with dropped columns omitted - if result, err := s.CreateTable(tableName, preparedColumns); err != nil { + if result, err = s.CreateTable(tableName, preparedColumns); err != nil { return result, err } // Move data from old table - if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName, + if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName, strings.Join(selectName(preparedColumns), ", "), proxyName)); err != nil { return result, err } // Clean up proxy table - if result, err := s.DropTable(proxyName); err != nil { + if result, err = s.DropTable(proxyName); err != nil { return result, err } // Recreate Indices - var err error - var result sql.Result for _, idx := range indices { if result, err = s.Tx.Exec(idx); err != nil { return result, err @@ -129,12 +133,15 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq } func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) { - sql, err := s.getDDLFromTable(tableName) + var err error + var result sql.Result + + tableSQL, err := s.getDDLFromTable(tableName) if err != nil { return nil, err } - columns, err := fetchColumns(sql) + columns, err := fetchColumns(tableSQL) if err != nil { return nil, err } @@ -161,7 +168,11 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] } // fetch indices for this table - oldSQLIndices := s.getDDLFromIndex(tableName) + oldSQLIndices, err := s.getDDLFromIndex(tableName) + if err != nil { + return nil, err + } + var indices []string for _, idx := range oldSQLIndices { added := false @@ -185,23 +196,21 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] } // Create new table with the new columns - if result, err := s.CreateTable(tableName, newColumns); err != nil { + if result, err = s.CreateTable(tableName, newColumns); err != nil { return result, err } // Migrate data - if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName, + if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName, strings.Join(oldColumnsName, ", "), proxyName)); err != nil { return result, err } // Clean up proxy table - if result, err := s.DropTable(proxyName); err != nil { + if result, err = s.DropTable(proxyName); err != nil { return result, err } - var err error - var result sql.Result for _, idx := range indices { if result, err = s.Tx.Exec(idx); err != nil { return result, err @@ -219,3 +228,32 @@ func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { } return sql, nil } + +func (s *SQLiteDriver) getDDLFromIndex(tableName string) ([]string, error) { + var sqls []string + + query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?;` + rows, err := s.Tx.Query(query, tableName) + if err != nil { + return sqls, err + } + + for rows.Next() { + var sql string + if err := rows.Scan(&sql); err != nil { + if strings.Contains(err.Error(), "Scan pair: -> *string") { + continue + } + return sqls, err + } + if len(sql) > 0 { + sqls = append(sqls, sql) + } + } + + if err := rows.Err(); err != nil { + return sqls, err + } + + return sqls, nil +}