// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"crypto/sha1"
	"crypto/tls"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net/url"
	"os"
	"strings"
	"time"
)

var (
	// errLog is the package-level error logger. It is created in init and
	// writes to os.Stderr.
	errLog *log.Logger
	// tlsConfigRegister holds the custom tls.Configs added through
	// RegisterTLSConfig, keyed by the name used as tls=<name> in a DSN.
	tlsConfigRegister map[string]*tls.Config

	// errInvalidDSNUnescaped is returned by parseDSN when an address is opened
	// with '(' and a ')' exists, but not directly in front of the last '/'.
	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
	// errInvalidDSNAddr is returned by parseDSN when an address is opened with
	// '(' and no closing ')' is found before the last '/'.
	errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
)

// init sets up the package-level state: a logger on os.Stderr with the
// "[MySQL] " prefix (date, time and short file name flags) and an empty
// registry for custom TLS configurations.
func init() {
	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
	tlsConfigRegister = make(map[string]*tls.Config)
}

// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { // log.Fatal(err) // } // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { // log.Fatal("Failed to append PEM.") // } // clientCert := make([]tls.Certificate, 0, 1) // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") // if err != nil { // log.Fatal(err) // } // clientCert = append(clientCert, certs) // mysql.RegisterTLSConfig("custom", &tls.Config{ // RootCAs: rootCertPool, // Certificates: clientCert, // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { return fmt.Errorf("Key '%s' is reserved", key) } tlsConfigRegister[key] = config return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } // parseDSN parses the DSN string to a config func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) // TODO: use strings.IndexByte when we can depend on Go 1.2 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') for i := len(dsn) - 1; i >= 0; i-- { if dsn[i] == '/' { var j, k int // left part is empty if i <= 0 if i > 0 { // [username[:password]@][protocol[(address)]] // Find the last '@' in dsn[:i] for j = i; j >= 0; j-- { if dsn[j] == '@' { // username[:password] // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { cfg.passwd = dsn[k+1 : j] break } } cfg.user = dsn[:k] break } } // [protocol[(address)]] // Find the first '(' in dsn[j+1:i] for k = j + 1; k < i; k++ { if dsn[k] == '(' { // dsn[i-1] must be == ')' if an adress is specified if dsn[i-1] != ')' { if strings.ContainsRune(dsn[k+1:i], ')') { return nil, errInvalidDSNUnescaped 
} return nil, errInvalidDSNAddr } cfg.addr = dsn[k+1 : i-1] break } } cfg.net = dsn[j+1 : k] } // dbname[?param1=value1&...¶mN=valueN] // Find the first '?' in dsn[i+1:] for j = i + 1; j < len(dsn); j++ { if dsn[j] == '?' { if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { return } break } } cfg.dbname = dsn[i+1 : j] break } } // Set default network if empty if cfg.net == "" { cfg.net = "tcp" } // Set default adress if empty if cfg.addr == "" { switch cfg.net { case "tcp": cfg.addr = "127.0.0.1:3306" case "unix": cfg.addr = "/tmp/mysql.sock" default: return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") } } // Set default location if empty if cfg.loc == nil { cfg.loc = time.UTC } return } // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed func parseDSNParams(cfg *config, params string) (err error) { for _, v := range strings.Split(params, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { continue } // cfg params switch value := param[1]; param[0] { // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool cfg.allowAllFiles, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Switch "rowsAffected" mode case "clientFoundRows": var isBool bool cfg.clientFoundRows, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Use old authentication mode (pre MySQL 4.1) case "allowOldPasswords": var isBool bool cfg.allowOldPasswords, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } // Time Location case "loc": if value, err = url.QueryUnescape(value); err != nil { return } cfg.loc, err = time.LoadLocation(value) if err != nil { return } // Dial Timeout case "timeout": cfg.timeout, err = time.ParseDuration(value) if err != nil { return } // TLS-Encryption case "tls": boolValue, isBool := readBool(value) if isBool { if boolValue { cfg.tls = &tls.Config{} } } 
else { if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} } else if tlsConfig, ok := tlsConfigRegister[value]; ok { cfg.tls = tlsConfig } else { return fmt.Errorf("Invalid value / unknown config name: %s", value) } } default: // lazy init if cfg.params == nil { cfg.params = make(map[string]string) } if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { return } } } return } // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { switch input { case "1", "true", "TRUE", "True": return true, true case "0", "false", "FALSE", "False": return false, true } // Not a valid bool value return } /****************************************************************************** * Authentication * ******************************************************************************/ // Encrypt password using 4.1+ method func scramblePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } // stage1Hash = SHA1(password) crypt := sha1.New() crypt.Write(password) stage1 := crypt.Sum(nil) // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) // inner Hash crypt.Reset() crypt.Write(stage1) hash := crypt.Sum(nil) // outer Hash crypt.Reset() crypt.Write(scramble) crypt.Write(hash) scramble = crypt.Sum(nil) // token = scrambleHash XOR stage1Hash for i := range scramble { scramble[i] ^= stage1[i] } return scramble } // Encrypt password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { seed1, seed2 uint32 } const myRndMaxVal = 0x3FFFFFFF // Pseudo random number generator func newMyRnd(seed1, seed2 uint32) *myRnd { return &myRnd{ seed1: seed1 % myRndMaxVal, seed2: seed2 % myRndMaxVal, } } // Tested to be equivalent to MariaDB's floating point variant // http://play.golang.org/p/QHvhd4qved // http://play.golang.org/p/RG0q4ElWDx func (r 
*myRnd) NextByte() byte { r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal return byte(uint64(r.seed1) * 31 / myRndMaxVal) } // Generate binary hash from byte string using insecure pre 4.1 method func pwHash(password []byte) (result [2]uint32) { var add uint32 = 7 var tmp uint32 result[0] = 1345345333 result[1] = 0x12345671 for _, c := range password { // skip spaces and tabs in password if c == ' ' || c == '\t' { continue } tmp = uint32(c) result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) result[1] += (result[1] << 8) ^ result[0] add += tmp } // Remove sign bit (1<<31)-1) result[0] &= 0x7FFFFFFF result[1] &= 0x7FFFFFFF return } // Encrypt password using insecure pre 4.1 method func scrambleOldPassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } scramble = scramble[:8] hashPw := pwHash(password) hashSc := pwHash(scramble) r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) var out [8]byte for i := range out { out[i] = r.NextByte() + 64 } mask := r.NextByte() for i := range out { out[i] ^= mask } return out[:] } /****************************************************************************** * Time related utils * ******************************************************************************/ // NullTime represents a time.Time that may be NULL. // NullTime implements the Scanner interface so // it can be used as a scan destination: // // var nt NullTime // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) // ... // if nt.Valid { // // use nt.Time // } else { // // NULL value // } // // This NullTime implementation is not driver-specific type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } // Scan implements the Scanner interface. // The value type must be time.Time or string / []byte (formatted time-string), // otherwise Scan fails. 
func (nt *NullTime) Scan(value interface{}) (err error) { if value == nil { nt.Time, nt.Valid = time.Time{}, false return } switch v := value.(type) { case time.Time: nt.Time, nt.Valid = v, true return case []byte: nt.Time, err = parseDateTime(string(v), time.UTC) nt.Valid = (err == nil) return case string: nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return } nt.Valid = false return fmt.Errorf("Can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. func (nt NullTime) Value() (driver.Value, error) { if !nt.Valid { return nil, nil } return nt.Time, nil } func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { switch len(str) { case 10: // YYYY-MM-DD if str == "0000-00-00" { return } t, err = time.Parse(timeFormat[:10], str) case 19: // YYYY-MM-DD HH:MM:SS if str == "0000-00-00 00:00:00" { return } t, err = time.Parse(timeFormat, str) default: err = fmt.Errorf("Invalid Time-String: %s", str) return } // Adjust location if err == nil && loc != time.UTC { y, mo, d := t.Date() h, mi, s := t.Clock() t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil } return } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { switch num { case 0: return time.Time{}, nil case 4: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day 0, 0, 0, 0, loc, ), nil case 7: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds 0, loc, ), nil case 11: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds loc, ), nil } return nil, fmt.Errorf("Invalid DATETIME-packet 
length %d", num) } func formatBinaryDate(num uint64, data []byte) (driver.Value, error) { switch num { case 0: return []byte("0000-00-00"), nil case 4: return []byte(fmt.Sprintf( "%04d-%02d-%02d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], )), nil } return nil, fmt.Errorf("Invalid DATE-packet length %d", num) } func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { switch num { case 0: return []byte("0000-00-00 00:00:00"), nil case 4: return []byte(fmt.Sprintf( "%04d-%02d-%02d 00:00:00", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], )), nil case 7: return []byte(fmt.Sprintf( "%04d-%02d-%02d %02d:%02d:%02d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], data[4], data[5], data[6], )), nil case 11: return []byte(fmt.Sprintf( "%04d-%02d-%02d %02d:%02d:%02d.%06d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], data[4], data[5], data[6], binary.LittleEndian.Uint32(data[7:11]), )), nil } return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } /****************************************************************************** * Convert from and to bytes * ******************************************************************************/ func uint64ToBytes(n uint64) []byte { return []byte{ byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56), } } func uint64ToString(n uint64) []byte { var a [20]byte i := 20 // U+0030 = 0 // ... 
// U+0039 = 9 var q uint64 for n >= 10 { i-- q = n / 10 a[i] = uint8(n-q*10) + 0x30 n = q } i-- a[i] = uint8(n) + 0x30 return a[i:] } // treats string value as unsigned integer representation func stringToInt(b []byte) int { val := 0 for i := range b { val *= 10 val += int(b[i] - 0x30) } return val } // returns the string read as a bytes slice, wheter the value is NULL, // the number of bytes read and an error, in case the string is longer than // the input slice func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { return b[n:n], isNull, n, nil } n += int(num) // Check data length if len(b) >= n { return b[n-int(num) : n], false, n, nil } return nil, false, n, io.EOF } // returns the number of bytes skipped and an error, in case the string is // longer than the input slice func skipLengthEnodedString(b []byte) (int, error) { // Get length num, _, n := readLengthEncodedInteger(b) if num < 1 { return n, nil } n += int(num) // Check data length if len(b) >= n { return n, nil } return n, io.EOF } // returns the number read, whether the value is NULL and the number of bytes read func readLengthEncodedInteger(b []byte) (uint64, bool, int) { switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 // 252: value of following 2 case 0xfc: return uint64(b[1]) | uint64(b[2])<<8, false, 3 // 253: value of following 3 case 0xfd: return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 // 254: value of following 8 case 0xfe: return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | uint64(b[7])<<48 | uint64(b[8])<<54, false, 9 } // 0-250: value of first byte return uint64(b[0]), false, 1 } // encodes a uint64 value and appends it to the given bytes slice func appendLengthEncodedInteger(b []byte, n uint64) []byte { switch { case n <= 250: return append(b, byte(n)) case n <= 0xffff: return append(b, 0xfc, byte(n), byte(n>>8)) 
case n <= 0xffffff: return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) } return b }