Merge pull request #171 from fudanchii/migration

[RFC] database refactorization
This commit is contained in:
Brad Rydzewski 2014-03-25 12:45:30 -07:00
commit fe1f2666f3
23 changed files with 965 additions and 289 deletions

View file

@ -25,6 +25,7 @@ deps:
go get github.com/drone/go-bitbucket/bitbucket
go get github.com/GeertJohan/go.rice
go get github.com/GeertJohan/go.rice/rice
go get github.com/go-sql-driver/mysql
go get github.com/mattn/go-sqlite3
go get github.com/russross/meddler

View file

@ -64,6 +64,27 @@ you can still get a feel for the steps:
https://docs.google.com/file/d/0By8deR1ROz8memUxV0lTSGZPQUk
**Using MySQL**
By default, Drone use sqlite as its database storage. To use MySQL/MariaDB instead, use `-driver` flag
and set it to `mysql`. You will need to set your DSN (`-datasource`) in this form:
```
user:password@tcp(hostname:port)/dbname?parseTime=true
```
Change it according to your database settings. The parseTime above is required since drone using
`time.Time` to represents `TIMESTAMP` data. Please refer to [1] for more options on mysql driver.
You may also need to tweak some innodb options, especially if you're using `utf8mb4` collation type.
```
innodb_file_format = Barracuda
innodb_file_per_table = On
innodb_large_prefix = On
```
Please consult to the MySQL/MariaDB documentation for further information
regarding large prefix for index column and dynamic row format (which is used in Drone).
[1] https://github.com/go-sql-driver/mysql
### Builds
Drone use a **.drone.yml** configuration file in the root of your

View file

@ -1,7 +1,6 @@
package main
import (
"database/sql"
"flag"
"log"
"net/http"
@ -12,23 +11,15 @@ import (
"code.google.com/p/go.net/websocket"
"github.com/GeertJohan/go.rice"
"github.com/bmizerany/pat"
_ "github.com/mattn/go-sqlite3"
"github.com/russross/meddler"
"github.com/drone/drone/pkg/build/docker"
"github.com/drone/drone/pkg/channel"
"github.com/drone/drone/pkg/database"
"github.com/drone/drone/pkg/database/migrate"
"github.com/drone/drone/pkg/handler"
"github.com/drone/drone/pkg/queue"
)
var (
// local path where the SQLite database
// should be stored. By default this is
// in the current working directory.
path string
// port the server will run on
port string
@ -57,7 +48,6 @@ var (
func main() {
// parse command line flags
flag.StringVar(&path, "path", "", "")
flag.StringVar(&port, "port", ":8080", "")
flag.StringVar(&driver, "driver", "sqlite3", "")
flag.StringVar(&datasource, "datasource", "drone.sqlite", "")
@ -71,7 +61,9 @@ func main() {
checkTLSFlags()
// setup database and handlers
setupDatabase()
if err := database.Init(driver, datasource); err != nil {
log.Fatal("Can't initialize database: ", err)
}
setupStatic()
setupHandlers()
@ -97,25 +89,6 @@ func checkTLSFlags() {
}
// setup the database connection and register with the
// global database package.
func setupDatabase() {
// inform meddler and migration we're using sqlite
meddler.Default = meddler.SQLite
migrate.Driver = migrate.SQLite
// connect to the SQLite database
db, err := sql.Open(driver, datasource)
if err != nil {
log.Fatal(err)
}
database.Set(db)
migration := migrate.New(db)
migration.All().Migrate()
}
// setup routes for static assets. These assets may
// be directly embedded inside the application using
// the `rice embed` command, else they are served from disk.

View file

@ -16,7 +16,7 @@ SELECT id, repo_id, status, started, finished, duration,
hash, branch, pull_request, author, gravatar, timestamp, message, created, updated
FROM commits
WHERE repo_id = ? AND branch = ?
ORDER BY created DESC
ORDER BY created DESC, id DESC
LIMIT 10
`
@ -26,7 +26,7 @@ SELECT id, repo_id, status, started, finished, duration,
hash, branch, pull_request, author, gravatar, timestamp, message, created, updated
FROM commits
WHERE repo_id = ? AND branch = ?
ORDER BY created DESC
ORDER BY created DESC, id DESC
LIMIT 1
`
@ -57,7 +57,7 @@ WHERE r.user_id = ?
AND r.team_id = 0
AND r.id = c.repo_id
AND c.status IN ('Success', 'Failure')
ORDER BY c.created desc
ORDER BY c.created desc, c.id desc
LIMIT 10
`
@ -70,7 +70,7 @@ FROM repos r, commits c
WHERE r.team_id = ?
AND r.id = c.repo_id
AND c.status IN ('Success', 'Failure')
ORDER BY c.created desc
ORDER BY c.created desc, c.id desc
LIMIT 10
`

View file

@ -2,23 +2,62 @@ package database
import (
"database/sql"
"log"
"fmt"
"github.com/drone/drone/pkg/database/schema"
"github.com/drone/drone/pkg/database/migrate"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/russross/meddler"
)
// global instance of our database connection.
var db *sql.DB
// Set sets the default database.
func Set(database *sql.DB) {
// set the global database
db = database
// load the database schema. If this is
// a new database all the tables and
// indexes will be created.
if err := schema.Load(db); err != nil {
log.Fatal(err)
// Init connects to database and performs migration if necessary.
//
// Database driver name and data source information is provided by user
// from within command line, and error checking is deferred to sql.Open.
//
// Init will just bail out and returns error if driver name
// is not listed, no fallback nor default driver sets here.
func Init(name, datasource string) error {
var err error
driver := map[string]struct {
Md *meddler.Database
Mg migrate.DriverBuilder
}{
"sqlite3": {
meddler.SQLite,
migrate.SQLite,
},
"mysql": {
meddler.MySQL,
migrate.MySQL,
},
}
if drv, ok := driver[name]; ok {
meddler.Default = drv.Md
migrate.Driver = drv.Mg
} else {
return fmt.Errorf("%s driver not found", name)
}
db, err = sql.Open(name, datasource)
if err != nil {
return err
}
migration := migrate.New(db)
if err := migration.All().Migrate(); err != nil {
return err
}
return nil
}
// Close database connection.
func Close() {
db.Close()
}

View file

@ -0,0 +1,153 @@
package migrate
type rev1st struct{}
var SetupTables = &rev1st{}
func (r *rev1st) Revision() int64 {
return 1
}
func (r *rev1st) Up(mg *MigrationDriver) error {
t := mg.T
if _, err := mg.CreateTable("users", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.String("email", UNIQUE),
t.String("password"),
t.String("token", UNIQUE),
t.String("name"),
t.String("gravatar"),
t.Timestamp("created"),
t.Timestamp("updated"),
t.Bool("admin"),
t.String("github_login"),
t.String("github_token"),
t.String("bitbucket_login"),
t.String("bitbucket_token"),
t.String("bitbucket_secret"),
}); err != nil {
return err
}
if _, err := mg.CreateTable("teams", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.String("slug", UNIQUE),
t.String("name"),
t.String("email"),
t.String("gravatar"),
t.Timestamp("created"),
t.Timestamp("updated"),
}); err != nil {
return err
}
if _, err := mg.CreateTable("members", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.Integer("team_id"),
t.Integer("user_id"),
t.String("role"),
}); err != nil {
return err
}
if _, err := mg.CreateTable("repos", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.String("slug", UNIQUE),
t.String("host"),
t.String("owner"),
t.String("name"),
t.Bool("private"),
t.Bool("disabled"),
t.Bool("disabled_pr"),
t.Bool("priveleged"),
t.Integer("timeout"),
t.Varchar("scm", 25),
t.Varchar("url", 1024),
t.String("username"),
t.String("password"),
t.Varchar("public_key", 1024),
t.Varchar("private_key", 1024),
t.Blob("params"),
t.Timestamp("created"),
t.Timestamp("updated"),
t.Integer("user_id"),
t.Integer("team_id"),
}); err != nil {
return err
}
if _, err := mg.CreateTable("commits", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.Integer("repo_id"),
t.String("status"),
t.Timestamp("started"),
t.Timestamp("finished"),
t.Integer("duration"),
t.Integer("attempts"),
t.String("hash"),
t.String("branch"),
t.String("pull_request"),
t.String("author"),
t.String("gravatar"),
t.String("timestamp"),
t.String("message"),
t.Timestamp("created"),
t.Timestamp("updated"),
}); err != nil {
return err
}
if _, err := mg.CreateTable("builds", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.Integer("commit_id"),
t.String("slug"),
t.String("status"),
t.Timestamp("started"),
t.Timestamp("finished"),
t.Integer("duration"),
t.Timestamp("created"),
t.Timestamp("updated"),
t.Text("stdout"),
}); err != nil {
return err
}
_, err := mg.CreateTable("settings", []string{
t.Integer("id", PRIMARYKEY, AUTOINCREMENT),
t.String("github_key"),
t.String("github_secret"),
t.String("bitbucket_key"),
t.String("bitbucket_secret"),
t.Varchar("smtp_server", 1024),
t.Varchar("smtp_port", 5),
t.Varchar("smtp_address", 1024),
t.Varchar("smtp_username", 1024),
t.Varchar("smtp_password", 1024),
t.Varchar("hostname", 1024),
t.Varchar("scheme", 5),
})
return err
}
func (r *rev1st) Down(mg *MigrationDriver) error {
if _, err := mg.DropTable("settings"); err != nil {
return err
}
if _, err := mg.DropTable("builds"); err != nil {
return err
}
if _, err := mg.DropTable("commits"); err != nil {
return err
}
if _, err := mg.DropTable("repos"); err != nil {
return err
}
if _, err := mg.DropTable("members"); err != nil {
return err
}
if _, err := mg.DropTable("teams"); err != nil {
return err
}
_, err := mg.DropTable("users")
return err
}

View file

@ -8,15 +8,15 @@ func (r *Rev1) Revision() int64 {
return 201402200603
}
func (r *Rev1) Up(op Operation) error {
_, err := op.RenameColumns("repos", map[string]string{
func (r *Rev1) Up(mg *MigrationDriver) error {
_, err := mg.RenameColumns("repos", map[string]string{
"priveleged": "privileged",
})
return err
}
func (r *Rev1) Down(op Operation) error {
_, err := op.RenameColumns("repos", map[string]string{
func (r *Rev1) Down(mg *MigrationDriver) error {
_, err := mg.RenameColumns("repos", map[string]string{
"privileged": "priveleged",
})
return err

View file

@ -8,19 +8,19 @@ func (r *Rev3) Revision() int64 {
return 201402211147
}
func (r *Rev3) Up(op Operation) error {
_, err := op.AddColumn("settings", "github_domain VARCHAR(255)")
func (r *Rev3) Up(mg *MigrationDriver) error {
_, err := mg.AddColumn("settings", "github_domain VARCHAR(255)")
if err != nil {
return err
}
_, err = op.AddColumn("settings", "github_apiurl VARCHAR(255)")
_, err = mg.AddColumn("settings", "github_apiurl VARCHAR(255)")
op.Exec("update settings set github_domain=?", "github.com")
op.Exec("update settings set github_apiurl=?", "https://api.github.com")
mg.Tx.Exec("update settings set github_domain=?", "github.com")
mg.Tx.Exec("update settings set github_apiurl=?", "https://api.github.com")
return err
}
func (r *Rev3) Down(op Operation) error {
_, err := op.DropColumns("settings", []string{"github_domain", "github_apiurl"})
func (r *Rev3) Down(mg *MigrationDriver) error {
_, err := mg.DropColumns("settings", "github_domain", "github_apiurl")
return err
}

View file

@ -0,0 +1,21 @@
package migrate
type rev20140310104446 struct{}
var AddOpenInvitationColumn = &rev20140310104446{}
func (r *rev20140310104446) Revision() int64 {
return 20140310104446
}
func (r *rev20140310104446) Up(mg *MigrationDriver) error {
// Suppress error here for backward compatibility
_, err := mg.AddColumn("settings", "open_invitations BOOLEAN")
_, err = mg.Tx.Exec("UPDATE settings SET open_invitations=0 WHERE open_invitations IS NULL")
return err
}
func (r *rev20140310104446) Down(mg *MigrationDriver) error {
_, err := mg.DropColumns("settings", "open_invitations")
return err
}

View file

@ -0,0 +1,83 @@
package migrate
type rev2nd struct{}
var SetupIndices = &rev2nd{}
func (r *rev2nd) Revision() int64 {
return 2
}
func (r *rev2nd) Up(mg *MigrationDriver) error {
if _, err := mg.AddIndex("members", []string{"team_id", "user_id"}, "unique"); err != nil {
return err
}
if _, err := mg.AddIndex("members", []string{"team_id"}); err != nil {
return err
}
if _, err := mg.AddIndex("members", []string{"user_id"}); err != nil {
return err
}
if _, err := mg.AddIndex("commits", []string{"repo_id", "hash", "branch"}, "unique"); err != nil {
return err
}
if _, err := mg.AddIndex("commits", []string{"repo_id"}); err != nil {
return err
}
if _, err := mg.AddIndex("commits", []string{"repo_id", "branch"}); err != nil {
return err
}
if _, err := mg.AddIndex("repos", []string{"team_id"}); err != nil {
return err
}
if _, err := mg.AddIndex("repos", []string{"user_id"}); err != nil {
return err
}
if _, err := mg.AddIndex("builds", []string{"commit_id"}); err != nil {
return err
}
_, err := mg.AddIndex("builds", []string{"commit_id", "slug"})
return err
}
func (r *rev2nd) Down(mg *MigrationDriver) error {
if _, err := mg.DropIndex("builds", []string{"commit_id", "slug"}); err != nil {
return err
}
if _, err := mg.DropIndex("builds", []string{"commit_id"}); err != nil {
return err
}
if _, err := mg.DropIndex("repos", []string{"user_id"}); err != nil {
return err
}
if _, err := mg.DropIndex("repos", []string{"team_id"}); err != nil {
return err
}
if _, err := mg.DropIndex("commits", []string{"repo_id", "branch"}); err != nil {
return err
}
if _, err := mg.DropIndex("commits", []string{"repo_id"}); err != nil {
return err
}
if _, err := mg.DropIndex("commits", []string{"repo_id", "hash", "branch"}); err != nil {
return err
}
if _, err := mg.DropIndex("members", []string{"user_id"}); err != nil {
return err
}
if _, err := mg.DropIndex("members", []string{"team_id"}); err != nil {
return err
}
_, err := mg.DropIndex("members", []string{"team_id", "user_id"})
return err
}

View file

@ -1,10 +1,17 @@
package migrate
// All is called to collect all migration scripts
// and adds them to Revision list. New Revision
// should be added here ordered by its revision
// number.
func (m *Migration) All() *Migration {
// List all migrations here
m.Add(SetupTables)
m.Add(SetupIndices)
m.Add(RenamePrivelegedToPrivileged)
m.Add(GitHubEnterpriseSupport)
m.Add(AddOpenInvitationColumn)
// m.Add(...)
// ...

View file

@ -0,0 +1,63 @@
package migrate
import (
"database/sql"
)
// Operation interface covers basic migration operations.
// Implementation details is specific for each database,
// see migrate/sqlite.go for implementation reference.
type Operation interface {
// CreateTable may be used to create a table named `tableName`
// with its columns specification listed in `args` as an array of string
CreateTable(tableName string, args []string) (sql.Result, error)
// RenameTable simply rename table from `tableName` to `newName`
RenameTable(tableName, newName string) (sql.Result, error)
// DropTable drops table named `tableName`
DropTable(tableName string) (sql.Result, error)
// AddColumn adds single new column to `tableName`, columnSpec is
// a standard column definition (column name included) which may looks like this:
//
// mg.AddColumn("example", "email VARCHAR(255) UNIQUE")
//
// it's equivalent to:
//
// mg.AddColumn("example", mg.T.String("email", UNIQUE))
//
AddColumn(tableName, columnSpec string) (sql.Result, error)
// ChangeColumn may be used to change the type of a column
// `newType` should always specify the column's new type even
// if the type is not meant to be change. Eg.
//
// mg.ChangeColumn("example", "name", "VARCHAR(255) UNIQUE")
//
ChangeColumn(tableName, columnName, newType string) (sql.Result, error)
// DropColumns drops a list of columns
DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error)
// RenameColumns will rename columns listed in `columnChanges`
RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error)
// AddIndex adds index on `tableName` indexed by `columns`
AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error)
// DropIndex drops index indexed by `columns` from `tableName`
DropIndex(tableName string, columns []string) (sql.Result, error)
}
// MigrationDriver drives migration script by injecting transaction object (*sql.Tx),
// `Operation` implementation and column type helper.
type MigrationDriver struct {
Operation
T *columnType
Tx *sql.Tx
}
// DriverBuilder is a constructor for MigrationDriver
type DriverBuilder func(tx *sql.Tx) *MigrationDriver

View file

@ -0,0 +1,96 @@
package migrate
import (
"fmt"
"reflect"
"strings"
)
const (
UNIQUE int = iota
PRIMARYKEY
AUTOINCREMENT
NULL
NOTNULL
)
// columnType will be injected to migration script
// along with MigrationDriver. `AttrMap` is used to
// defines distinct column's attribute between database
// implementation. e.g. 'AUTOINCREMENT' in sqlite and
// 'AUTO_INCREMENT' in mysql.
type columnType struct {
Driver string
AttrMap map[int]string
}
// defaultMap defines default values for column's attribute
// lookup.
var defaultMap = map[int]string{
UNIQUE: "UNIQUE",
PRIMARYKEY: "PRIMARY KEY",
AUTOINCREMENT: "AUTOINCREMENT",
NULL: "NULL",
NOTNULL: "NOT NULL",
}
// Integer returns column definition for INTEGER typed column.
// Additional attributes may be specified as string or predefined key
// listed in defaultMap.
func (c *columnType) Integer(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s INTEGER %s", colName, c.parseAttr(spec))
}
// String returns column definition for VARCHAR(255) typed column.
func (c *columnType) String(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s VARCHAR(255) %s", colName, c.parseAttr(spec))
}
// Text returns column definition for TEXT typed column.
func (c *columnType) Text(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s TEXT %s", colName, c.parseAttr(spec))
}
// Blob returns column definition for BLOB typed column
func (c *columnType) Blob(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s BLOB %s", colName, c.parseAttr(spec))
}
// Timestamp returns column definition for TIMESTAMP typed column
func (c *columnType) Timestamp(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s TIMESTAMP %s", colName, c.parseAttr(spec))
}
// Bool returns column definition for BOOLEAN typed column
func (c *columnType) Bool(colName string, spec ...interface{}) string {
return fmt.Sprintf("%s BOOLEAN %s", colName, c.parseAttr(spec))
}
// Varchar returns column definition for VARCHAR typed column.
// column's max length is specified as `length`.
func (c *columnType) Varchar(colName string, length int, spec ...interface{}) string {
return fmt.Sprintf("%s VARCHAR(%d) %s", colName, length, c.parseAttr(spec))
}
// attr returns string representation of column attribute specified as key for defaultMap.
func (c *columnType) attr(flag int) string {
if v, ok := c.AttrMap[flag]; ok {
return v
}
return defaultMap[flag]
}
// parseAttr reflects spec value for its type and returns the string
// representation returned by `attr`
func (c *columnType) parseAttr(spec []interface{}) string {
var attrs []string
for _, v := range spec {
switch reflect.ValueOf(v).Kind() {
case reflect.Int:
attrs = append(attrs, c.attr(v.(int)))
case reflect.String:
attrs = append(attrs, v.(string))
}
}
return strings.Join(attrs, " ")
}

View file

@ -1,24 +1,3 @@
// Usage
// migrate.To(2)
// .Add(Version_1)
// .Add(Version_2)
// .Add(Version_3)
// .Exec(db)
//
// migrate.ToLatest()
// .Add(Version_1)
// .Add(Version_2)
// .Add(Version_3)
// .SetDialect(migrate.MySQL)
// .Exec(db)
//
// migrate.ToLatest()
// .Add(Version_1)
// .Add(Version_2)
// .Add(Version_3)
// .Backup(path)
// .Exec()
package migrate
import (
@ -28,7 +7,7 @@ import (
const migrationTableStmt = `
CREATE TABLE IF NOT EXISTS migration (
revision NUMBER PRIMARY KEY
revision BIGINT PRIMARY KEY
)
`
@ -49,45 +28,18 @@ const deleteRevisionStmt = `
DELETE FROM migration where revision = ?
`
// Operation interface covers basic migration operations.
// Implementation details is specific for each database,
// see migrate/sqlite.go for implementation reference.
type Operation interface {
CreateTable(tableName string, args []string) (sql.Result, error)
RenameTable(tableName, newName string) (sql.Result, error)
DropTable(tableName string) (sql.Result, error)
AddColumn(tableName, columnSpec string) (sql.Result, error)
DropColumns(tableName string, columnsToDrop []string) (sql.Result, error)
RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Revision interface {
Up(op Operation) error
Down(op Operation) error
Up(mg *MigrationDriver) error
Down(mg *MigrationDriver) error
Revision() int64
}
type MigrationDriver struct {
Tx *sql.Tx
}
type Migration struct {
db *sql.DB
revs []Revision
}
var Driver func(tx *sql.Tx) Operation
var Driver DriverBuilder
func New(db *sql.DB) *Migration {
return &Migration{db: db}
@ -99,7 +51,7 @@ func (m *Migration) Add(rev ...Revision) *Migration {
return m
}
// Execute the full list of migrations.
// Migrate executes the full list of migrations.
func (m *Migration) Migrate() error {
var target int64
if len(m.revs) > 0 {
@ -111,7 +63,7 @@ func (m *Migration) Migrate() error {
return m.MigrateTo(target)
}
// Execute all database migration until
// MigrateTo executes all database migration until
// you are at the specified revision number.
// If the revision number is less than the
// current revision, then we will downgrade.
@ -148,14 +100,14 @@ func (m *Migration) up(target, current int64) error {
return err
}
op := Driver(tx)
mg := Driver(tx)
// loop through and execute revisions
for _, rev := range m.revs {
if rev.Revision() > current && rev.Revision() <= target {
current = rev.Revision()
// execute the revision Upgrade.
if err := rev.Up(op); err != nil {
if err := rev.Up(mg); err != nil {
log.Printf("Failed to upgrade to Revision Number %v\n", current)
log.Println(err)
return tx.Rollback()
@ -181,7 +133,7 @@ func (m *Migration) down(target, current int64) error {
return err
}
op := Driver(tx)
mg := Driver(tx)
// reverse the list of revisions
revs := []Revision{}
@ -195,7 +147,7 @@ func (m *Migration) down(target, current int64) error {
if rev.Revision() > target {
current = rev.Revision()
// execute the revision Upgrade.
if err := rev.Down(op); err != nil {
if err := rev.Down(mg); err != nil {
log.Printf("Failed to downgrade from Revision Number %v\n", current)
log.Println(err)
return tx.Rollback()

View file

@ -9,6 +9,24 @@ titleize() {
echo "$1" | sed -r -e "s/-|_/ /g" -e 's/\b(.)/\U\1/g' -e 's/ //g'
}
howto() {
echo "Usage:"
echo " ./migration create_sample_table"
echo ""
echo "Above invocation will create a migration script called:"
echo " ${REV}_create_sample_table.go"
echo "You can add your migration step at the Up and Down function"
echo "definition inside the file."
echo ""
echo "Database transaction available through MigrationDriver,"
echo "so you can access mg.Tx (sql.Tx instance) directly,"
echo "there are also some migration helpers available, see api.go"
echo "for the list of available helpers (Operation interface)."
echo ""
}
[[ $# -eq 0 ]] && howto && exit 0
cat > ${REV}_$filename.go << EOF
package migrate
@ -20,11 +38,11 @@ func (r *rev$REV) Revision() int64 {
${TAB}return $REV
}
func (r *rev$REV) Up(op Operation) error {
func (r *rev$REV) Up(mg *MigrationDriver) error {
${TAB}// Migration steps here
}
func (r *rev$REV) Down(op Operation) error {
func (r *rev$REV) Down(mg *MigrationDriver) error {
${TAB}// Revert migration steps here
}
EOF

View file

@ -0,0 +1,109 @@
package migrate
import (
"database/sql"
"fmt"
"strings"
)
type mysqlDriver struct {
Tx *sql.Tx
}
func MySQL(tx *sql.Tx) *MigrationDriver {
return &MigrationDriver{
Tx: tx,
Operation: &mysqlDriver{Tx: tx},
T: &columnType{
AttrMap: map[int]string{AUTOINCREMENT: "AUTO_INCREMENT"},
},
}
}
func (m *mysqlDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("CREATE TABLE %s (%s) ROW_FORMAT=DYNAMIC", tableName, strings.Join(args, ", ")))
}
func (m *mysqlDriver) RenameTable(tableName, newName string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s", tableName, newName))
}
func (m *mysqlDriver) DropTable(tableName string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
}
func (m *mysqlDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN (%s)", tableName, columnSpec))
}
func (m *mysqlDriver) ChangeColumn(tableName, columnName, newSpecs string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s MODIFY %s %s", tableName, columnName, newSpecs))
}
func (m *mysqlDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
for k, v := range columnsToDrop {
columnsToDrop[k] = fmt.Sprintf("DROP %s", v)
}
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(columnsToDrop, ", ")))
}
func (m *mysqlDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
var columns []string
tableSQL, err := m.getTableDefinition(tableName)
if err != nil {
return nil, err
}
columns, err = fetchColumns(tableSQL)
if err != nil {
return nil, err
}
var colspec []string
for k, v := range columnChanges {
for _, col := range columns {
col = strings.Trim(col, " \n")
cols := strings.SplitN(col, " ", 2)
if quote(k) == cols[0] {
colspec = append(colspec, fmt.Sprintf("CHANGE %s %s %s", k, v, cols[1]))
break
}
}
}
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(colspec, ", ")))
}
func (m *mysqlDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
flag := ""
if len(flags) > 0 {
switch strings.ToUpper(flags[0]) {
case "UNIQUE":
fallthrough
case "FULLTEXT":
fallthrough
case "SPATIAL":
flag = flags[0]
}
}
return m.Tx.Exec(fmt.Sprintf("CREATE %s INDEX %s ON %s (%s)", flag,
indexName(tableName, columns), tableName, strings.Join(columns, ", ")))
}
func (m *mysqlDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
return m.Tx.Exec(fmt.Sprintf("DROP INDEX %s on %s", indexName(tableName, columns), tableName))
}
func (m *mysqlDriver) getTableDefinition(tableName string) (string, error) {
var name, def string
st := fmt.Sprintf("SHOW CREATE TABLE %s", tableName)
if err := m.Tx.QueryRow(st).Scan(&name, &def); err != nil {
return "", err
}
return def, nil
}
func quote(name string) string {
return fmt.Sprintf("`%s`", name)
}

View file

@ -0,0 +1,53 @@
package migrate
import (
"database/sql"
"errors"
)
type postgresqlDriver struct {
Tx *sql.Tx
}
func PostgreSQL(tx *sql.Tx) *MigrationDriver {
return &MigrationDriver{
Tx: tx,
Operation: &postgresqlDriver{Tx: tx},
}
}
func (p *postgresqlDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) RenameTable(tableName, newName string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) DropTable(tableName string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) ChangeColumn(tableName, columnName, newSpecs string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}
func (p *postgresqlDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
return nil, errors.New("not implemented yet")
}

View file

@ -4,46 +4,94 @@ import (
"database/sql"
"fmt"
"strings"
"github.com/dchest/uniuri"
_ "github.com/mattn/go-sqlite3"
)
type SQLiteDriver MigrationDriver
func SQLite(tx *sql.Tx) Operation {
return &SQLiteDriver{Tx: tx}
type sqliteDriver struct {
Tx *sql.Tx
}
func (s *SQLiteDriver) Exec(query string, args ...interface{}) (sql.Result, error) {
return s.Tx.Exec(query, args...)
func SQLite(tx *sql.Tx) *MigrationDriver {
return &MigrationDriver{
Tx: tx,
Operation: &sqliteDriver{Tx: tx},
T: &columnType{},
}
}
func (s *SQLiteDriver) Query(query string, args ...interface{}) (*sql.Rows, error) {
return s.Tx.Query(query, args...)
func (s *sqliteDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("CREATE TABLE %s (%s)", tableName, strings.Join(args, ", ")))
}
func (s *SQLiteDriver) QueryRow(query string, args ...interface{}) *sql.Row {
return s.Tx.QueryRow(query, args...)
func (s *sqliteDriver) RenameTable(tableName, newName string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s", tableName, newName))
}
func (s *SQLiteDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("CREATE TABLE %s (%s);", tableName, strings.Join(args, ", ")))
func (s *sqliteDriver) DropTable(tableName string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
}
func (s *SQLiteDriver) RenameTable(tableName, newName string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tableName, newName))
func (s *sqliteDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableName, columnSpec))
}
func (s *SQLiteDriver) DropTable(tableName string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
func (s *sqliteDriver) ChangeColumn(tableName, columnName, newType string) (sql.Result, error) {
var result sql.Result
var err error
tableSQL, err := s.getTableDefinition(tableName)
if err != nil {
return nil, err
}
columns, err := fetchColumns(tableSQL)
if err != nil {
return nil, err
}
columnNames := selectName(columns)
for k, column := range columnNames {
if columnName == column {
columns[k] = fmt.Sprintf("%s %s", columnName, newType)
break
}
}
indices, err := s.getIndexDefinition(tableName)
if err != nil {
return nil, err
}
proxy := proxyName(tableName)
if result, err = s.RenameTable(tableName, proxy); err != nil {
return nil, err
}
if result, err = s.CreateTable(tableName, columns); err != nil {
return nil, err
}
// Migrate data
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
strings.Join(columnNames, ", "), proxy)); err != nil {
return result, err
}
// Clean up proxy table
if result, err = s.DropTable(proxy); err != nil {
return result, err
}
for _, idx := range indices {
if result, err = s.Tx.Exec(idx); err != nil {
return result, err
}
}
return result, err
}
func (s *SQLiteDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", tableName, columnSpec))
}
func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sql.Result, error) {
func (s *sqliteDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
var err error
var result sql.Result
@ -51,7 +99,7 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
return nil, fmt.Errorf("No columns to drop.")
}
tableSQL, err := s.getDDLFromTable(tableName)
tableSQL, err := s.getTableDefinition(tableName)
if err != nil {
return nil, err
}
@ -82,7 +130,7 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
// fetch indices for this table
oldSQLIndices, err := s.getDDLFromIndex(tableName)
oldSQLIndices, err := s.getIndexDefinition(tableName)
if err != nil {
return nil, err
}
@ -114,8 +162,8 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
// Rename old table, here's our proxy
proxyName := fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16))
if result, err := s.RenameTable(tableName, proxyName); err != nil {
proxy := proxyName(tableName)
if result, err := s.RenameTable(tableName, proxy); err != nil {
return result, err
}
@ -125,13 +173,13 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
}
// Move data from old table
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s;", tableName,
strings.Join(selectName(preparedColumns), ", "), proxyName)); err != nil {
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
strings.Join(selectName(preparedColumns), ", "), proxy)); err != nil {
return result, err
}
// Clean up proxy table
if result, err = s.DropTable(proxyName); err != nil {
if result, err = s.DropTable(proxy); err != nil {
return result, err
}
@ -144,11 +192,11 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
return result, err
}
func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
func (s *sqliteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
var err error
var result sql.Result
tableSQL, err := s.getDDLFromTable(tableName)
tableSQL, err := s.getTableDefinition(tableName)
if err != nil {
return nil, err
}
@ -180,7 +228,7 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
}
// fetch indices for this table
oldSQLIndices, err := s.getDDLFromIndex(tableName)
oldSQLIndices, err := s.getIndexDefinition(tableName)
if err != nil {
return nil, err
}
@ -214,8 +262,8 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
}
// Rename current table
proxyName := fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16))
if result, err := s.RenameTable(tableName, proxyName); err != nil {
proxy := proxyName(tableName)
if result, err := s.RenameTable(tableName, proxy); err != nil {
return result, err
}
@ -226,12 +274,12 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
// Migrate data
if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
strings.Join(oldColumnsName, ", "), proxyName)); err != nil {
strings.Join(oldColumnsName, ", "), proxy)); err != nil {
return result, err
}
// Clean up proxy table
if result, err = s.DropTable(proxyName); err != nil {
if result, err = s.DropTable(proxy); err != nil {
return result, err
}
@ -243,9 +291,24 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
return result, err
}
func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
func (s *sqliteDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
flag := ""
if len(flags) > 0 {
if strings.ToUpper(flags[0]) == "UNIQUE" {
flag = flags[0]
}
}
return s.Tx.Exec(fmt.Sprintf("CREATE %s INDEX %s ON %s (%s)", flag, indexName(tableName, columns),
tableName, strings.Join(columns, ", ")))
}
func (s *sqliteDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
return s.Tx.Exec(fmt.Sprintf("DROP INDEX %s", indexName(tableName, columns)))
}
func (s *sqliteDriver) getTableDefinition(tableName string) (string, error) {
var sql string
query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?;`
query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?`
err := s.Tx.QueryRow(query, tableName).Scan(&sql)
if err != nil {
return "", err
@ -253,26 +316,23 @@ func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
return sql, nil
}
func (s *SQLiteDriver) getDDLFromIndex(tableName string) ([]string, error) {
func (s *sqliteDriver) getIndexDefinition(tableName string) ([]string, error) {
var sqls []string
query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?;`
query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?`
rows, err := s.Tx.Query(query, tableName)
if err != nil {
return sqls, err
}
for rows.Next() {
var sql string
var sql sql.NullString
if err := rows.Scan(&sql); err != nil {
// This error came from autoindex, since its sql value is null,
// we want to continue.
if strings.Contains(err.Error(), "Scan pair: <nil> -> *string") {
continue
}
return sqls, err
}
sqls = append(sqls, sql)
if sql.Valid {
sqls = append(sqls, sql.String)
}
}
if err := rows.Err(); err != nil {

View file

@ -1,4 +1,4 @@
package migrate
package migrate_test
import (
"database/sql"
@ -6,6 +6,9 @@ import (
"strings"
"testing"
. "github.com/drone/drone/pkg/database/migrate"
_ "github.com/mattn/go-sqlite3"
"github.com/russross/meddler"
)
@ -33,8 +36,8 @@ type AddColumnSample struct {
type revision1 struct{}
func (r *revision1) Up(op Operation) error {
_, err := op.CreateTable("samples", []string{
func (r *revision1) Up(mg *MigrationDriver) error {
_, err := mg.CreateTable("samples", []string{
"id INTEGER PRIMARY KEY AUTOINCREMENT",
"imel VARCHAR(255) UNIQUE",
"name VARCHAR(255)",
@ -42,8 +45,8 @@ func (r *revision1) Up(op Operation) error {
return err
}
func (r *revision1) Down(op Operation) error {
_, err := op.DropTable("samples")
func (r *revision1) Down(mg *MigrationDriver) error {
_, err := mg.DropTable("samples")
return err
}
@ -57,13 +60,13 @@ func (r *revision1) Revision() int64 {
type revision2 struct{}
func (r *revision2) Up(op Operation) error {
_, err := op.RenameTable("samples", "examples")
func (r *revision2) Up(mg *MigrationDriver) error {
_, err := mg.RenameTable("samples", "examples")
return err
}
func (r *revision2) Down(op Operation) error {
_, err := op.RenameTable("examples", "samples")
func (r *revision2) Down(mg *MigrationDriver) error {
_, err := mg.RenameTable("examples", "samples")
return err
}
@ -77,16 +80,16 @@ func (r *revision2) Revision() int64 {
type revision3 struct{}
func (r *revision3) Up(op Operation) error {
if _, err := op.AddColumn("samples", "url VARCHAR(255)"); err != nil {
func (r *revision3) Up(mg *MigrationDriver) error {
if _, err := mg.AddColumn("samples", "url VARCHAR(255)"); err != nil {
return err
}
_, err := op.AddColumn("samples", "num INTEGER")
_, err := mg.AddColumn("samples", "num INTEGER")
return err
}
func (r *revision3) Down(op Operation) error {
_, err := op.DropColumns("samples", []string{"num", "url"})
func (r *revision3) Down(mg *MigrationDriver) error {
_, err := mg.DropColumns("samples", "num", "url")
return err
}
@ -100,15 +103,15 @@ func (r *revision3) Revision() int64 {
type revision4 struct{}
func (r *revision4) Up(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision4) Up(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"imel": "email",
})
return err
}
func (r *revision4) Down(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision4) Down(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"email": "imel",
})
return err
@ -124,13 +127,13 @@ func (r *revision4) Revision() int64 {
type revision5 struct{}
func (r *revision5) Up(op Operation) error {
_, err := op.Exec(`CREATE INDEX samples_url_name_ix ON samples (url, name)`)
func (r *revision5) Up(mg *MigrationDriver) error {
_, err := mg.AddIndex("samples", []string{"url", "name"})
return err
}
func (r *revision5) Down(op Operation) error {
_, err := op.Exec(`DROP INDEX samples_url_name_ix`)
func (r *revision5) Down(mg *MigrationDriver) error {
_, err := mg.DropIndex("samples", []string{"url", "name"})
return err
}
@ -143,15 +146,15 @@ func (r *revision5) Revision() int64 {
// ---------- revision 6
type revision6 struct{}
func (r *revision6) Up(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision6) Up(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"url": "host",
})
return err
}
func (r *revision6) Down(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision6) Down(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"host": "url",
})
return err
@ -166,16 +169,16 @@ func (r *revision6) Revision() int64 {
// ---------- revision 7
type revision7 struct{}
func (r *revision7) Up(op Operation) error {
_, err := op.DropColumns("samples", []string{"host", "num"})
func (r *revision7) Up(mg *MigrationDriver) error {
_, err := mg.DropColumns("samples", "host", "num")
return err
}
func (r *revision7) Down(op Operation) error {
if _, err := op.AddColumn("samples", "host VARCHAR(255)"); err != nil {
func (r *revision7) Down(mg *MigrationDriver) error {
if _, err := mg.AddColumn("samples", "host VARCHAR(255)"); err != nil {
return err
}
_, err := op.AddColumn("samples", "num INSTEGER")
_, err := mg.AddColumn("samples", "num INSTEGER")
return err
}
@ -188,16 +191,16 @@ func (r *revision7) Revision() int64 {
// ---------- revision 8
type revision8 struct{}
func (r *revision8) Up(op Operation) error {
if _, err := op.AddColumn("samples", "repo_id INTEGER"); err != nil {
func (r *revision8) Up(mg *MigrationDriver) error {
if _, err := mg.AddColumn("samples", "repo_id INTEGER"); err != nil {
return err
}
_, err := op.AddColumn("samples", "repo VARCHAR(255)")
_, err := mg.AddColumn("samples", "repo VARCHAR(255)")
return err
}
func (r *revision8) Down(op Operation) error {
_, err := op.DropColumns("samples", []string{"repo", "repo_id"})
func (r *revision8) Down(mg *MigrationDriver) error {
_, err := mg.DropColumns("samples", "repo", "repo_id")
return err
}
@ -210,15 +213,15 @@ func (r *revision8) Revision() int64 {
// ---------- revision 9
type revision9 struct{}
func (r *revision9) Up(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision9) Up(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"repo": "repository",
})
return err
}
func (r *revision9) Down(op Operation) error {
_, err := op.RenameColumns("samples", map[string]string{
func (r *revision9) Down(mg *MigrationDriver) error {
_, err := mg.RenameColumns("samples", map[string]string{
"repository": "repo",
})
return err
@ -230,6 +233,26 @@ func (r *revision9) Revision() int64 {
// ---------- end of revision 9
// ---------- revision 10
type revision10 struct{}
func (r *revision10) Revision() int64 {
return 10
}
func (r *revision10) Up(mg *MigrationDriver) error {
_, err := mg.ChangeColumn("samples", "email", "varchar(512) UNIQUE")
return err
}
func (r *revision10) Down(mg *MigrationDriver) error {
_, err := mg.ChangeColumn("samples", "email", "varchar(255) unique")
return err
}
// ---------- end of revision 10
var db *sql.DB
var testSchema = `
@ -252,11 +275,9 @@ func TestMigrateCreateTable(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
if err := mgr.Add(&revision1{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
sample := Sample{
@ -265,7 +286,7 @@ func TestMigrateCreateTable(t *testing.T) {
Name: "Test Tester",
}
if err := meddler.Save(db, "samples", &sample); err != nil {
t.Errorf("Can not save data: %q", err)
t.Fatalf("Can not save data: %q", err)
}
}
@ -275,22 +296,20 @@ func TestMigrateRenameTable(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
if err := mgr.Add(&revision1{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
loadFixture(t)
if err := mgr.Add(&revision2{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
sample := Sample{}
if err := meddler.QueryRow(db, &sample, `SELECT * FROM examples WHERE id = ?`, 2); err != nil {
t.Errorf("Can not fetch data: %q", err)
t.Fatalf("Can not fetch data: %q", err)
}
if sample.Imel != "foo@bar.com" {
@ -313,16 +332,14 @@ func TestMigrateAddRemoveColumns(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
if err := mgr.Add(&revision1{}, &revision3{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var columns []*TableInfo
if err := meddler.QueryAll(db, &columns, `PRAGMA table_info(samples);`); err != nil {
t.Errorf("Can not access table info: %q", err)
t.Fatalf("Can not access table info: %q", err)
}
if len(columns) < 5 {
@ -337,16 +354,16 @@ func TestMigrateAddRemoveColumns(t *testing.T) {
Num: 42,
}
if err := meddler.Save(db, "samples", &row); err != nil {
t.Errorf("Can not save into database: %q", err)
t.Fatalf("Can not save into database: %q", err)
}
if err := mgr.MigrateTo(1); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var another_columns []*TableInfo
if err := meddler.QueryAll(db, &another_columns, `PRAGMA table_info(samples);`); err != nil {
t.Errorf("Can not access table info: %q", err)
t.Fatalf("Can not access table info: %q", err)
}
if len(another_columns) != 3 {
@ -360,22 +377,20 @@ func TestRenameColumn(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
if err := mgr.Add(&revision1{}, &revision4{}).MigrateTo(1); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
loadFixture(t)
if err := mgr.MigrateTo(4); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
row := RenameSample{}
if err := meddler.QueryRow(db, &row, `SELECT * FROM samples WHERE id = 3;`); err != nil {
t.Errorf("Can not query database: %q", err)
t.Fatalf("Can not query database: %q", err)
}
if row.Email != "crash@bandicoot.io" {
@ -389,22 +404,20 @@ func TestMigrateExistingTable(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
if _, err := db.Exec(testSchema); err != nil {
t.Errorf("Can not create database: %q", err)
t.Fatalf("Can not create database: %q", err)
}
loadFixture(t)
mgr := New(db)
if err := mgr.Add(&revision4{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var rows []*RenameSample
if err := meddler.QueryAll(db, &rows, `SELECT * from samples;`); err != nil {
t.Errorf("Can not query database: %q", err)
t.Fatalf("Can not query database: %q", err)
}
if len(rows) != 3 {
@ -426,49 +439,47 @@ func TestIndexOperations(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
// Migrate, create index
if err := mgr.Add(&revision1{}, &revision3{}, &revision5{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var esquel []*sqliteMaster
// Query sqlite_master, check if index is exists.
query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name='samples'`
if err := meddler.QueryAll(db, &esquel, query); err != nil {
t.Errorf("Can not find index: %q", err)
t.Fatalf("Can not find index: %q", err)
}
indexStatement := `CREATE INDEX samples_url_name_ix ON samples (url, name)`
indexStatement := `CREATE INDEX idx_samples_on_url_and_name ON samples (url, name)`
if string(esquel[1].Sql.([]byte)) != indexStatement {
t.Errorf("Can not find index")
t.Errorf("Can not find index, got: %q", esquel[1])
}
// Migrate, rename indexed columns
if err := mgr.Add(&revision6{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var esquel1 []*sqliteMaster
if err := meddler.QueryAll(db, &esquel1, query); err != nil {
t.Errorf("Can not find index: %q", err)
t.Fatalf("Can not find index: %q", err)
}
indexStatement = `CREATE INDEX samples_host_name_ix ON samples (host, name)`
indexStatement = `CREATE INDEX idx_samples_on_host_and_name ON samples (host, name)`
if string(esquel1[1].Sql.([]byte)) != indexStatement {
t.Errorf("Can not find index, got: %s", esquel[0])
t.Errorf("Can not find index, got: %q", esquel1[1])
}
if err := mgr.Add(&revision7{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var esquel2 []*sqliteMaster
if err := meddler.QueryAll(db, &esquel2, query); err != nil {
t.Errorf("Can not find index: %q", err)
t.Fatalf("Can not find index: %q", err)
}
if len(esquel2) != 1 {
@ -482,17 +493,15 @@ func TestColumnRedundancy(t *testing.T) {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
migr := New(db)
if err := migr.Add(&revision1{}, &revision8{}, &revision9{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
t.Fatalf("Can not migrate: %q", err)
}
var tableSql string
query := `SELECT sql FROM sqlite_master where type='table' and name='samples'`
if err := db.QueryRow(query).Scan(&tableSql); err != nil {
t.Errorf("Can not query sqlite_master: %q", err)
t.Fatalf("Can not query sqlite_master: %q", err)
}
if !strings.Contains(tableSql, "repository ") {
@ -500,8 +509,31 @@ func TestColumnRedundancy(t *testing.T) {
}
}
func TestChangeColumnType(t *testing.T) {
defer tearDown()
if err := setUp(); err != nil {
t.Fatalf("Error preparing database: %q", err)
}
migr := New(db)
if err := migr.Add(&revision1{}, &revision4{}, &revision10{}).Migrate(); err != nil {
t.Fatalf("Can not migrate: %q", err)
}
var tableSql string
query := `SELECT sql FROM sqlite_master where type='table' and name='samples'`
if err := db.QueryRow(query).Scan(&tableSql); err != nil {
t.Fatalf("Can not query sqlite_master: %q", err)
}
if !strings.Contains(tableSql, "email varchar(512) UNIQUE") {
t.Errorf("Expect email type to changed: %q", tableSql)
}
}
func setUp() error {
var err error
Driver = SQLite
db, err = sql.Open("sqlite3", "migration_tests.sqlite")
return err
}
@ -514,7 +546,7 @@ func tearDown() {
func loadFixture(t *testing.T) {
for _, sql := range dataDump {
if _, err := db.Exec(sql); err != nil {
t.Errorf("Can not insert into database: %q", err)
t.Fatalf("Can not insert into database: %q", err)
}
}
}

View file

@ -3,6 +3,8 @@ package migrate
import (
"fmt"
"strings"
"github.com/dchest/uniuri"
)
func fetchColumns(sql string) ([]string, error) {
@ -30,3 +32,11 @@ func setForUpdate(left []string, right []string) string {
}
return strings.Join(results, ", ")
}
func proxyName(tableName string) string {
return fmt.Sprintf("%s_%s", tableName, uniuri.NewLen(16))
}
func indexName(tableName string, columns []string) string {
return fmt.Sprintf("idx_%s_on_%s", tableName, strings.Join(columns, "_and_"))
}

View file

@ -16,31 +16,31 @@ func TestGetCommit(t *testing.T) {
}
if commit.ID != 1 {
t.Errorf("Exepected ID %d, got %d", 1, commit.ID)
t.Errorf("Expected ID %d, got %d", 1, commit.ID)
}
if commit.Status != "Success" {
t.Errorf("Exepected Status %s, got %s", "Success", commit.Status)
t.Errorf("Expected Status %s, got %s", "Success", commit.Status)
}
if commit.Hash != "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608" {
t.Errorf("Exepected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
t.Errorf("Expected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
}
if commit.Branch != "master" {
t.Errorf("Exepected Branch %s, got %s", "master", commit.Branch)
t.Errorf("Expected Branch %s, got %s", "master", commit.Branch)
}
if commit.Author != "brad.rydzewski@gmail.com" {
t.Errorf("Exepected Author %s, got %s", "master", commit.Author)
t.Errorf("Expected Author %s, got %s", "master", commit.Author)
}
if commit.Message != "commit message" {
t.Errorf("Exepected Message %s, got %s", "master", commit.Message)
t.Errorf("Expected Message %s, got %s", "master", commit.Message)
}
if commit.Gravatar != "8c58a0be77ee441bb8f8595b7f1b4e87" {
t.Errorf("Exepected Gravatar %s, got %s", "8c58a0be77ee441bb8f8595b7f1b4e87", commit.Gravatar)
t.Errorf("Expected Gravatar %s, got %s", "8c58a0be77ee441bb8f8595b7f1b4e87", commit.Gravatar)
}
}
@ -54,15 +54,15 @@ func TestGetCommitHash(t *testing.T) {
}
if commit.ID != 1 {
t.Errorf("Exepected ID %d, got %d", 1, commit.ID)
t.Errorf("Expected ID %d, got %d", 1, commit.ID)
}
if commit.Hash != "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608" {
t.Errorf("Exepected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
t.Errorf("Expected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
}
if commit.Status != "Success" {
t.Errorf("Exepected Status %s, got %s", "Success", commit.Status)
t.Errorf("Expected Status %s, got %s", "Success", commit.Status)
}
}
@ -91,11 +91,11 @@ func TestSaveCommit(t *testing.T) {
}
if commit.Hash != updatedCommit.Hash {
t.Errorf("Exepected Hash %s, got %s", updatedCommit.Hash, commit.Hash)
t.Errorf("Expected Hash %s, got %s", updatedCommit.Hash, commit.Hash)
}
if commit.Status != "Failing" {
t.Errorf("Exepected Status %s, got %s", updatedCommit.Status, commit.Status)
t.Errorf("Expected Status %s, got %s", updatedCommit.Status, commit.Status)
}
}
@ -126,7 +126,7 @@ func TestListCommits(t *testing.T) {
// verify commit count
if len(commits) != 2 {
t.Errorf("Exepected %d commits in database, got %d", 2, len(commits))
t.Errorf("Expected %d commits in database, got %d", 2, len(commits))
return
}
@ -135,30 +135,30 @@ func TestListCommits(t *testing.T) {
commit := commits[1] // TODO something strange is happening with ordering here
if commit.ID != 1 {
t.Errorf("Exepected ID %d, got %d", 1, commit.ID)
t.Errorf("Expected ID %d, got %d", 1, commit.ID)
}
if commit.Status != "Success" {
t.Errorf("Exepected Status %s, got %s", "Success", commit.Status)
t.Errorf("Expected Status %s, got %s", "Success", commit.Status)
}
if commit.Hash != "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608" {
t.Errorf("Exepected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
t.Errorf("Expected Hash %s, got %s", "4f4c4594be6d6ddbc1c0dd521334f7ecba92b608", commit.Hash)
}
if commit.Branch != "master" {
t.Errorf("Exepected Branch %s, got %s", "master", commit.Branch)
t.Errorf("Expected Branch %s, got %s", "master", commit.Branch)
}
if commit.Author != "brad.rydzewski@gmail.com" {
t.Errorf("Exepected Author %s, got %s", "master", commit.Author)
t.Errorf("Expected Author %s, got %s", "master", commit.Author)
}
if commit.Message != "commit message" {
t.Errorf("Exepected Message %s, got %s", "master", commit.Message)
t.Errorf("Expected Message %s, got %s", "master", commit.Message)
}
if commit.Gravatar != "8c58a0be77ee441bb8f8595b7f1b4e87" {
t.Errorf("Exepected Gravatar %s, got %s", "8c58a0be77ee441bb8f8595b7f1b4e87", commit.Gravatar)
t.Errorf("Expected Gravatar %s, got %s", "8c58a0be77ee441bb8f8595b7f1b4e87", commit.Gravatar)
}
}

View file

@ -64,21 +64,21 @@ func TestIsMemberAdmin(t *testing.T) {
if ok, err := database.IsMemberAdmin(1, 1); err != nil {
t.Error(err)
} else if !ok {
t.Errorf("Expected IsMemberAdmin to return true, returned false")
t.Errorf("Expected user id 1 IsMemberAdmin to return true, returned false")
}
// expecting user is Admin
if ok, err := database.IsMemberAdmin(2, 1); err != nil {
t.Error(err)
} else if !ok {
t.Errorf("Expected IsMemberAdmin to return true, returned false")
t.Errorf("Expected user id 2 IsMemberAdmin to return true, returned false")
}
// expecting user is NOT Admin (Write role)
if ok, err := database.IsMemberAdmin(3, 1); err != nil {
t.Error(err)
} else if ok {
t.Errorf("Expected IsMemberAdmin to return false, returned true")
t.Errorf("Expected user id 3 IsMemberAdmin to return false, returned true")
}
}

View file

@ -2,22 +2,16 @@ package database
import (
"crypto/aes"
"database/sql"
"log"
"github.com/drone/drone/pkg/database"
"github.com/drone/drone/pkg/database/encrypt"
"github.com/drone/drone/pkg/database/migrate"
. "github.com/drone/drone/pkg/model"
_ "github.com/mattn/go-sqlite3"
"github.com/russross/meddler"
)
// in-memory database used for
// unit testing purposes.
var db *sql.DB
func init() {
// create a cipher for ecnrypting and decrypting
// database fields
@ -30,20 +24,11 @@ func init() {
// decrypt database fields.
meddler.Register("gobencrypt", &encrypt.EncryptedField{cipher})
// notify meddler that we are working with sqlite
meddler.Default = meddler.SQLite
migrate.Driver = migrate.SQLite
}
func Setup() {
// create an in-memory database
db, _ = sql.Open("sqlite3", ":memory:")
// make sure all the tables and indexes are created
database.Set(db)
migration := migrate.New(db)
migration.All().Migrate()
database.Init("sqlite3", ":memory:")
// create dummy user data
user1 := User{
@ -208,5 +193,5 @@ func Setup() {
}
func Teardown() {
db.Close()
database.Close()
}