woodpecker/vendor/github.com/denisenkom/go-mssqldb/tds_test.go
2015-09-29 18:21:17 -07:00

337 lines
7.5 KiB
Go

package mssql
import (
"bytes"
"database/sql"
"encoding/hex"
"fmt"
"os"
"testing"
"time"
)
// MockTransport is an in-memory transport for tests: reads and writes go to
// the embedded bytes.Buffer, and Close is a no-op that always succeeds.
type MockTransport struct {
	bytes.Buffer
}

// Close satisfies the closer part of the transport contract; there is no
// underlying resource to release, so it always returns nil.
func (*MockTransport) Close() error { return nil }
// TestSendLogin serializes a fixed login record with sendLogin into a
// tdsBuffer backed by MockTransport (an in-memory bytes.Buffer) and compares
// the bytes accumulated in the buffer with a hard-coded reference encoding.
func TestSendLogin(t *testing.T) {
	buf := newTdsBuffer(1024, new(MockTransport))
	login := login{
		TDSVersion:     verTDS73,
		PacketSize:     0x1000,
		ClientProgVer:  0x01060100,
		ClientPID:      100,
		ClientTimeZone: -4 * 60,
		ClientID:       [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab},
		OptionFlags1:   0xe0,
		OptionFlags3:   8,
		HostName:       "subdev1",
		UserName:       "test",
		Password:       "testpwd",
		AppName:        "appname",
		ServerName:     "servername",
		CtlIntName:     "library",
		Language:       "en",
		Database:       "database",
		ClientLCID:     0x204,
		AtchDBFile:     "filepath",
	}
	err := sendLogin(buf, login)
	if err != nil {
		// A byte comparison after a failed encode would only add noise.
		t.Fatal("sendLogin should succeed:", err)
	}
	// Expected wire bytes. NOTE(review): presumably the TDS packet header
	// followed by the LOGIN7 payload; must be regenerated if the encoder changes.
	ref := []byte{
		16, 1, 0, 222, 0, 0, 1, 0, 198 + 16, 0, 0, 0, 3, 0, 10, 115, 0, 16, 0, 0, 0, 1,
		6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 8, 16, 255, 255, 255, 4, 2, 0,
		0, 94, 0, 7, 0, 108, 0, 4, 0, 116, 0, 7, 0, 130, 0, 7, 0, 144, 0, 10, 0, 0,
		0, 0, 0, 164, 0, 7, 0, 178, 0, 2, 0, 182, 0, 8, 0, 18, 52, 86, 120, 144, 171,
		198, 0, 0, 0, 198, 0, 8, 0, 214, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98,
		0, 100, 0, 101, 0, 118, 0, 49, 0, 116, 0, 101, 0, 115, 0, 116, 0, 226, 165,
		243, 165, 146, 165, 226, 165, 162, 165, 210, 165, 227, 165, 97, 0, 112,
		0, 112, 0, 110, 0, 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0,
		101, 0, 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114,
		0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98,
		0, 97, 0, 115, 0, 101, 0, 102, 0, 105, 0, 108, 0, 101, 0, 112, 0, 97, 0,
		116, 0, 104, 0}
	out := buf.buf[:buf.pos]
	if !bytes.Equal(ref, out) {
		t.Error("input output don't match")
		// Route dumps through the test log so they are attached to this test's output.
		t.Log("expected:\n" + hex.Dump(ref))
		t.Log("actual:\n" + hex.Dump(out))
	}
}
// TestSendSqlBatch opens a raw driver connection using the HOST, INSTANCE,
// SQLUSER, SQLPASSWORD and DATABASE environment variables, sends the batch
// "select 1", and checks that the last row received before the done token
// holds int32(1).
func TestSendSqlBatch(t *testing.T) {
	addr := os.Getenv("HOST")
	instance := os.Getenv("INSTANCE")
	conn, err := connect(map[string]string{
		"server":   fmt.Sprintf("%s\\%s", addr, instance),
		"user id":  os.Getenv("SQLUSER"),
		"password": os.Getenv("SQLPASSWORD"),
		"database": os.Getenv("DATABASE"),
	})
	if err != nil {
		t.Error("Open connection failed:", err.Error())
		return
	}
	defer conn.buf.transport.Close()

	headers := []headerStruct{
		{hdrtype: dataStmHdrTransDescr,
			data: transDescrHdr{0, 1}.pack()},
	}
	err = sendSqlBatch72(conn.buf, "select 1", headers)
	if err != nil {
		t.Error("Sending sql batch failed", err.Error())
		return
	}

	ch := make(chan tokenStruct, 5)
	// NOTE(review): the loop below stops reading at the first done token; confirm
	// processResponse cannot block forever on the channel after that.
	go processResponse(conn, ch)
	var lastRow []interface{}
loop:
	for tok := range ch {
		switch token := tok.(type) {
		case doneStruct:
			break loop
		case []columnStruct:
			conn.columns = token
		case []interface{}:
			lastRow = token
		default:
			t.Log("unknown token", tok)
		}
	}
	// If the stream ends without a row token, lastRow is nil and indexing would panic.
	if len(lastRow) == 0 {
		t.Error("no row returned for select 1")
		return
	}
	switch value := lastRow[0].(type) {
	case int32:
		if value != 1 {
			t.Error("Invalid value returned, should be 1", value)
		}
	default:
		// Previously any other type passed silently without checking the value.
		t.Errorf("unexpected type %T for returned value %v", value, value)
	}
}
// makeConnStr returns the connection string used by the integration tests,
// filled in from the HOST, INSTANCE, SQLUSER, SQLPASSWORD and DATABASE
// environment variables and ending with a fixed "log=63" option.
func makeConnStr() string {
	const format = "Server=%s\\%s;User Id=%s;Password=%s;Database=%s;log=63"
	return fmt.Sprintf(format,
		os.Getenv("HOST"),
		os.Getenv("INSTANCE"),
		os.Getenv("SQLUSER"),
		os.Getenv("SQLPASSWORD"),
		os.Getenv("DATABASE"),
	)
}
// open returns a *sql.DB for the "mssql" driver using makeConnStr.
//
// A failure from sql.Open stops the calling test with t.Fatal. Returning nil
// instead would make callers that immediately `defer conn.Close()` panic.
// Must be called from the test's own goroutine because t.Fatal is used.
func open(t *testing.T) *sql.DB {
	conn, err := sql.Open("mssql", makeConnStr())
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	return conn
}
// TestConnect checks that sql.Open accepts the "mssql" driver name together
// with the DSN from makeConnStr. No Ping or query is issued here, so this
// does not by itself prove the server is reachable: database/sql may defer
// dialing until the first real use of the handle.
func TestConnect(t *testing.T) {
	db, err := sql.Open("mssql", makeConnStr())
	if err == nil {
		db.Close()
		return
	}
	t.Error("Open connection failed:", err.Error())
}
// TestBadConnect verifies that Ping fails for each DSN in badDsns (currently
// the configured server with invalid credentials).
func TestBadConnect(t *testing.T) {
	badDsns := []string{
		//"Server=badhost",
		fmt.Sprintf("Server=%s\\%s;User ID=baduser;Password=badpwd",
			os.Getenv("HOST"), os.Getenv("INSTANCE")),
	}
	for _, badDsn := range badDsns {
		conn, err := sql.Open("mssql", badDsn)
		if err != nil {
			// conn is nil here; using it would panic.
			t.Error("Open connection failed:", err.Error())
			continue
		}
		err = conn.Ping()
		// Close per iteration rather than deferring inside the loop, so handles
		// do not pile up until the test returns.
		conn.Close()
		if err == nil {
			t.Error("Ping should fail for connection: ", badDsn)
		}
	}
}
// simpleQuery prepares "select 1 as a" on conn. If Prepare fails, it records
// the error on t and returns nil; the caller owns closing the returned stmt.
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
	var err error
	if stmt, err = conn.Prepare("select 1 as a"); err != nil {
		t.Error("Prepare failed:", err.Error())
		return nil
	}
	return stmt
}
// checkSimpleQuery consumes rows and reports on t unless the result is exactly
// one row whose first column scans to the int value 1. It also reports any
// error that ended iteration early.
func checkSimpleQuery(rows *sql.Rows, t *testing.T) {
	numrows := 0
	for rows.Next() {
		var val int
		if err := rows.Scan(&val); err != nil {
			t.Error("Scan failed:", err.Error())
		} else if val != 1 {
			// Only compare the value when Scan succeeded; otherwise val is just 0.
			t.Error("query should return 1")
		}
		numrows++
	}
	// Next returns false both at end of results and on error; Err distinguishes them.
	if err := rows.Err(); err != nil {
		t.Error("iterating rows failed:", err.Error())
	}
	if numrows != 1 {
		t.Error("query should return 1 row, returned", numrows)
	}
}
// TestQuery prepares "select 1 as a", runs it once, and checks that the result
// has exactly one column named "a" and a single row holding 1.
func TestQuery(t *testing.T) {
	conn := open(t)
	if conn == nil {
		return
	}
	defer conn.Close()

	stmt := simpleQuery(conn, t)
	if stmt == nil {
		return
	}
	defer stmt.Close()

	rows, err := stmt.Query()
	if err != nil {
		// rows is nil on error; continuing would panic in rows.Close.
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		t.Error("getting columns failed", err.Error())
		return
	}
	// Either a wrong count or a wrong name is a failure. The original `&&`
	// accepted a single mis-named column and indexed an empty slice.
	if len(columns) != 1 || columns[0] != "a" {
		t.Error("returned incorrect columns (expected ['a']):", columns)
	}
	checkSimpleQuery(rows, t)
}
// TestMultipleQueriesSequentialy runs the same prepared statement twice, one
// after the other, and validates each result set with checkSimpleQuery.
func TestMultipleQueriesSequentialy(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open has already reported the failure; Close on a nil *sql.DB would panic.
		return
	}
	defer conn.Close()
	stmt, err := conn.Prepare("select 1 as a")
	if err != nil {
		t.Error("Prepare failed:", err.Error())
		return
	}
	defer stmt.Close()
	rows, err := stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
	rows, err = stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
}
// TestMultipleQueryClose runs a prepared statement, closes the result set
// without reading it, then runs the statement again and validates the second
// result set with checkSimpleQuery.
func TestMultipleQueryClose(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open has already reported the failure; Close on a nil *sql.DB would panic.
		return
	}
	defer conn.Close()
	stmt, err := conn.Prepare("select 1 as a")
	if err != nil {
		t.Error("Prepare failed:", err.Error())
		return
	}
	defer stmt.Close()
	rows, err := stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	// Deliberately close before reading any rows.
	rows.Close()
	rows, err = stmt.Query()
	if err != nil {
		t.Error("Query failed:", err.Error())
		return
	}
	defer rows.Close()
	checkSimpleQuery(rows, t)
}
// TestPing opens the test database and verifies that Ping succeeds.
func TestPing(t *testing.T) {
	conn := open(t)
	if conn == nil {
		// open has already reported the failure; Close on a nil *sql.DB would panic.
		return
	}
	defer conn.Close()
	// The original discarded this error, so the test could never fail.
	if err := conn.Ping(); err != nil {
		t.Error("Ping failed:", err.Error())
	}
}
// TestSecureWithInvalidHostName asks for an encrypted connection with
// certificate verification enabled and hostNameInCertificate=foo.bar, and
// expects Ping to fail. NOTE(review): this assumes the test server's
// certificate is not issued for foo.bar.
func TestSecureWithInvalidHostName(t *testing.T) {
	const tlsOpts = ";Encrypt=true;TrustServerCertificate=false;hostNameInCertificate=foo.bar"
	db, err := sql.Open("mssql", makeConnStr()+tlsOpts)
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	defer db.Close()
	if pingErr := db.Ping(); pingErr == nil {
		t.Fatal("Connected to fake foo.bar server")
	}
}
// TestSecureConnection opens a connection with Encrypt=true and
// TrustServerCertificate=true. It checks that a round-trip query works and
// that sys.dm_exec_connections reports the current session as encrypted.
func TestSecureConnection(t *testing.T) {
	db, err := sql.Open("mssql", makeConnStr()+";Encrypt=true;TrustServerCertificate=true")
	if err != nil {
		t.Fatal("Open connection failed:", err.Error())
	}
	defer db.Close()

	var got string
	if err := db.QueryRow("select 'secret'").Scan(&got); err != nil {
		t.Fatal("cannot scan value", err)
	}
	if got != "secret" {
		t.Fatal("expected secret, got: ", got)
	}

	const encQuery = "select encrypt_option from sys.dm_exec_connections where session_id=@@SPID"
	var encrypted bool
	if err := db.QueryRow(encQuery).Scan(&encrypted); err != nil {
		t.Fatal("cannot scan value", err)
	}
	if !encrypted {
		t.Fatal("connection is not encrypted")
	}
}
// TestParseConnectParamsKeepAlive checks that "keepAlive=60" in a connection
// string is parsed into a keepAlive duration of 60 seconds.
func TestParseConnectParamsKeepAlive(t *testing.T) {
	params := parseConnectionString("keepAlive=60")
	parsedParams, err := parseConnectParams(params)
	if err != nil {
		t.Fatal("cannot parse params: ", err)
	}
	want := time.Duration(60) * time.Second
	if parsedParams.keepAlive != want {
		// Report the actual value; a bare t.Fail() gave no hint what went wrong.
		t.Errorf("keepAlive = %v, want %v", parsedParams.keepAlive, want)
	}
}