package meddler

import (
	"database/sql"
	"fmt"
	"strings"
)

// DB is a generic database interface, matching both *sql.Db and *sql.Tx
type DB interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row
}

// Load loads a record using a query for the primary key field.
// Returns sql.ErrNoRows if not found.
func (d *Database) Load(db DB, table string, dst interface{}, pk int64) error {
	columns, err := d.ColumnsQuoted(dst, true)
	if err != nil {
		return err
	}

	// make sure we have a primary key field
	pkName, _, err := d.PrimaryKey(dst)
	if err != nil {
		return err
	}
	if pkName == "" {
		return fmt.Errorf("meddler.Load: no primary key field found")
	}

	// run the query
	q := fmt.Sprintf("SELECT %s FROM %s WHERE %s = %s", columns, d.quoted(table), d.quoted(pkName), d.Placeholder)

	rows, err := db.Query(q, pk)
	if err != nil {
		return fmt.Errorf("meddler.Load: DB error in Query: %v", err)
	}

	// scan the row
	return d.ScanRow(rows, dst)
}

// Load using the Default Database type
func Load(db DB, table string, dst interface{}, pk int64) error {
	return Default.Load(db, table, dst, pk)
}

// Insert performs an INSERT query for the given record.
// If the record has a primary key flagged, it must be zero, and it
// will be set to the newly-allocated primary key value from the database
// as returned by LastInsertId.
func (d *Database) Insert(db DB, table string, src interface{}) error {
	pkName, pkValue, err := d.PrimaryKey(src)
	if err != nil {
		return err
	}
	if pkName != "" && pkValue != 0 {
		return fmt.Errorf("meddler.Insert: primary key must be zero")
	}

	// gather the query parts
	namesPart, err := d.ColumnsQuoted(src, false)
	if err != nil {
		return err
	}
	valuesPart, err := d.PlaceholdersString(src, false)
	if err != nil {
		return err
	}
	values, err := d.Values(src, false)
	if err != nil {
		return err
	}

	// run the query
	q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", d.quoted(table), namesPart, valuesPart)
	if d.UseReturningToGetID && pkName != "" {
		q += " RETURNING " + d.quoted(pkName)
		var newPk int64
		err := db.QueryRow(q, values...).Scan(&newPk)
		if err != nil {
			return fmt.Errorf("meddler.Insert: DB error in QueryRow: %v", err)
		}
		if err = d.SetPrimaryKey(src, newPk); err != nil {
			return fmt.Errorf("meddler.Insert: Error saving updated pk: %v", err)
		}
	} else if pkName != "" {
		result, err := db.Exec(q, values...)
		if err != nil {
			return fmt.Errorf("meddler.Insert: DB error in Exec: %v", err)
		}

		// save the new primary key
		newPk, err := result.LastInsertId()
		if err != nil {
			return fmt.Errorf("meddler.Insert: DB error getting new primary key value: %v", err)
		}
		if err = d.SetPrimaryKey(src, newPk); err != nil {
			return fmt.Errorf("meddler.Insert: Error saving updated pk: %v", err)
		}
	} else {
		// no primary key, so no need to lookup new value
		_, err := db.Exec(q, values...)
		if err != nil {
			return fmt.Errorf("meddler.Insert: DB error in Exec: %v", err)
		}
	}

	return nil
}

// Insert using the Default Database type
func Insert(db DB, table string, src interface{}) error {
	return Default.Insert(db, table, src)
}

// Update performs and UPDATE query for the given record.
// The record must have an integer primary key field that is non-zero,
// and it will be used to select the database row that gets updated.
func (d *Database) Update(db DB, table string, src interface{}) error {
	// gather the query parts
	names, err := d.Columns(src, false)
	if err != nil {
		return err
	}
	placeholders, err := d.Placeholders(src, false)
	if err != nil {
		return err
	}
	values, err := d.Values(src, false)
	if err != nil {
		return err
	}

	// form the column=placeholder pairs
	var pairs []string
	for i := 0; i < len(names) && i < len(placeholders); i++ {
		pair := fmt.Sprintf("%s=%s", d.quoted(names[i]), placeholders[i])
		pairs = append(pairs, pair)
	}

	pkName, pkValue, err := d.PrimaryKey(src)
	if err != nil {
		return err
	}
	if pkName == "" {
		return fmt.Errorf("meddler.Update: no primary key field")
	}
	if pkValue < 1 {
		return fmt.Errorf("meddler.Update: primary key must be an integer > 0")
	}
	ph := d.placeholder(len(placeholders) + 1)

	// run the query
	q := fmt.Sprintf("UPDATE %s SET %s WHERE %s=%s", d.quoted(table),
		strings.Join(pairs, ","),
		d.quoted(pkName), ph)
	values = append(values, pkValue)

	if _, err := db.Exec(q, values...); err != nil {
		return fmt.Errorf("meddler.Update: DB error in Exec: %v", err)
	}

	return nil
}

// Update using the Default Database type
func Update(db DB, table string, src interface{}) error {
	return Default.Update(db, table, src)
}

// Save performs an INSERT or an UPDATE, depending on whether or not
// a primary keys exists and is non-zero.
func (d *Database) Save(db DB, table string, src interface{}) error {
	pkName, pkValue, err := d.PrimaryKey(src)
	if err != nil {
		return err
	}
	if pkName != "" && pkValue != 0 {
		return d.Update(db, table, src)
	} else {
		return d.Insert(db, table, src)
	}
}

// Save using the Default Database type
func Save(db DB, table string, src interface{}) error {
	return Default.Save(db, table, src)
}

// QueryOne performs the given query with the given arguments, scanning a
// single row of results into dst. Returns sql.ErrNoRows if there was no
// result row.
func (d *Database) QueryRow(db DB, dst interface{}, query string, args ...interface{}) error {
	// perform the query
	rows, err := db.Query(query, args...)
	if err != nil {
		return err
	}

	// gather the result
	return d.ScanRow(rows, dst)
}

// QueryRow using the Default Database type
func QueryRow(db DB, dst interface{}, query string, args ...interface{}) error {
	return Default.QueryRow(db, dst, query, args...)
}

// QueryAll performs the given query with the given arguments, scanning
// all results rows into dst.
func (d *Database) QueryAll(db DB, dst interface{}, query string, args ...interface{}) error {
	// perform the query
	rows, err := db.Query(query, args...)
	if err != nil {
		return err
	}

	// gather the results
	return d.ScanAll(rows, dst)
}

// QueryAll using the Default Database type
func QueryAll(db DB, dst interface{}, query string, args ...interface{}) error {
	return Default.QueryAll(db, dst, query, args...)
}