Add func to fetch SQL definition for indices.

Also rearrange sql.Result and err declaration.
This commit is contained in:
Nurahmadie 2014-02-18 21:22:22 +07:00
parent 7a75c2d004
commit 8d7cf16a89

View file

@ -44,17 +44,19 @@ func (s *SQLiteDriver) AddColumn(tableName, columnSpec string) (sql.Result, erro
}
func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sql.Result, error) {
var err error
var result sql.Result
if len(columnsToDrop) == 0 {
return nil, fmt.Errorf("No columns to drop.")
}
sql, err := s.getDDLFromTable(tableName)
tableSQL, err := s.getDDLFromTable(tableName)
if err != nil {
return nil, err
}
columns, err := fetchColumns(sql)
columns, err := fetchColumns(tableSQL)
if err != nil {
return nil, err
}
@ -80,8 +82,12 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
// fetch indices for this table
oldSQLIndices, err := s.getDDLFromIndex(tableName)
if err != nil {
return nil, err
}
var indices []string
oldSQLIndices := s.getDdlFromIndex(tableName)
for _, idx := range oldSQLIndices {
listed := false
for _, cols := range columnsToDrop {
@ -102,24 +108,22 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
// Recreate table with dropped columns omitted
if result, err := s.CreateTable(tableName, preparedColumns); err != nil {
if result, err = s.CreateTable(tableName, preparedColumns); err != nil {
return result, err
}
// Move data from old table
if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName,
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName,
strings.Join(selectName(preparedColumns), ", "), proxyName)); err != nil {
return result, err
}
// Clean up proxy table
if result, err := s.DropTable(proxyName); err != nil {
if result, err = s.DropTable(proxyName); err != nil {
return result, err
}
// Recreate Indices
var err error
var result sql.Result
for _, idx := range indices {
if result, err = s.Tx.Exec(idx); err != nil {
return result, err
@ -129,12 +133,15 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
sql, err := s.getDDLFromTable(tableName)
var err error
var result sql.Result
tableSQL, err := s.getDDLFromTable(tableName)
if err != nil {
return nil, err
}
columns, err := fetchColumns(sql)
columns, err := fetchColumns(tableSQL)
if err != nil {
return nil, err
}
@ -161,7 +168,11 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
}
// fetch indices for this table
oldSQLIndices := s.getDDLFromIndex(tableName)
oldSQLIndices, err := s.getDDLFromIndex(tableName)
if err != nil {
return nil, err
}
var indices []string
for _, idx := range oldSQLIndices {
added := false
@ -185,23 +196,21 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
}
// Create new table with the new columns
if result, err := s.CreateTable(tableName, newColumns); err != nil {
if result, err = s.CreateTable(tableName, newColumns); err != nil {
return result, err
}
// Migrate data
if result, err := s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
strings.Join(oldColumnsName, ", "), proxyName)); err != nil {
return result, err
}
// Clean up proxy table
if result, err := s.DropTable(proxyName); err != nil {
if result, err = s.DropTable(proxyName); err != nil {
return result, err
}
var err error
var result sql.Result
for _, idx := range indices {
if result, err = s.Tx.Exec(idx); err != nil {
return result, err
@ -219,3 +228,32 @@ func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
}
return sql, nil
}
func (s *SQLiteDriver) getDDLFromIndex(tableName string) ([]string, error) {
var sqls []string
query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?;`
rows, err := s.Tx.Query(query, tableName)
if err != nil {
return sqls, err
}
for rows.Next() {
var sql string
if err := rows.Scan(&sql); err != nil {
if strings.Contains(err.Error(), "Scan pair: <nil> -> *string") {
continue
}
return sqls, err
}
if len(sql) > 0 {
sqls = append(sqls, sql)
}
}
if err := rows.Err(); err != nil {
return sqls, err
}
return sqls, nil
}