moved to Godep with path rewrite

This commit is contained in:
Brad Rydzewski 2015-05-22 11:37:40 -07:00
parent 52c62d4c70
commit 35d604659a
598 changed files with 259201 additions and 100 deletions

129
Godeps/Godeps.json generated Normal file
View file

@ -0,0 +1,129 @@
{
"ImportPath": "github.com/drone/drone",
"GoVersion": "go1.4.2",
"Packages": [
"github.com/drone/drone/..."
],
"Deps": [
{
"ImportPath": "code.google.com/p/go.crypto/ssh",
"Comment": "null-236",
"Rev": "69e2a90ed92d03812364aeb947b7068dc42e561e"
},
{
"ImportPath": "github.com/BurntSushi/migration",
"Rev": "c45b897f13350786ccaf2b7403b92b1c7ad85844"
},
{
"ImportPath": "github.com/BurntSushi/toml",
"Comment": "v0.1.0-18-g443a628",
"Rev": "443a628bc233f634a75bcbdd71fe5350789f1afa"
},
{
"ImportPath": "github.com/Sirupsen/logrus",
"Comment": "v0.6.6-2-g2cea0f0",
"Rev": "2cea0f0d141f56fae06df5b813ec4119d1c8ccbd"
},
{
"ImportPath": "github.com/dgrijalva/jwt-go",
"Comment": "v2.2.0-16-gc48cfd5",
"Rev": "c48cfd5d9711c75acb6036d2698ef3aef7bb655a"
},
{
"ImportPath": "github.com/elazarl/go-bindata-assetfs",
"Rev": "bea323321994103859d60197d229f1a94699dde3"
},
{
"ImportPath": "github.com/franela/goblin",
"Comment": "0.0.1-53-g2042c4d",
"Rev": "2042c4d610d2da4e4d2afd754da0b4eef31753bb"
},
{
"ImportPath": "github.com/gin-gonic/gin",
"Comment": "v0.4-255-gcaa75c6",
"Rev": "caa75c620153398c334dba46601a92dc9cc00c5d"
},
{
"ImportPath": "github.com/google/go-github/github",
"Rev": "2431923100d4989611dcc339f1ff8d280bd9c84c"
},
{
"ImportPath": "github.com/google/go-querystring/query",
"Rev": "d8840cbb2baa915f4836edda4750050a2c0b7aea"
},
{
"ImportPath": "github.com/gorilla/securecookie",
"Rev": "8e98dd730fc43f1383f19615db6f2e702c9738e8"
},
{
"ImportPath": "github.com/hashicorp/golang-lru",
"Rev": "d85392d6bc30546d352f52f2632814cde4201d44"
},
{
"ImportPath": "github.com/lib/pq",
"Comment": "go1.0-cutoff-47-g93e9980",
"Rev": "93e9980741c9e593411b94e07d5bad8cfb4809db"
},
{
"ImportPath": "github.com/manucorporat/sse",
"Rev": "c574f6c50c8594f93d28b03a1bbd87b4a3899093"
},
{
"ImportPath": "github.com/mattn/go-colorable",
"Rev": "043ae16291351db8465272edf465c9f388161627"
},
{
"ImportPath": "github.com/mattn/go-sqlite3",
"Rev": "542ae647f8601bafd96233961b150cae198e0295"
},
{
"ImportPath": "github.com/naoina/go-stringutil",
"Rev": "360db0db4b01d34e12a2ec042c09e7d37fece761"
},
{
"ImportPath": "github.com/naoina/toml",
"Rev": "7b2dffbeaee47506726f29e36d19cf4ee90d361b"
},
{
"ImportPath": "github.com/russross/meddler",
"Rev": "cd98050d9328eb98ed57de34c74660dea3ca5df4"
},
{
"ImportPath": "github.com/samalba/dockerclient",
"Rev": "0fdc3ca0e58365801f1212900def9c7c60bbe2c7"
},
{
"ImportPath": "github.com/stretchr/objx",
"Rev": "cbeaeb16a013161a98496fad62933b1d21786672"
},
{
"ImportPath": "github.com/stretchr/testify/assert",
"Rev": "73a8112e05603706e47153d5cbb86aac8d85adb0"
},
{
"ImportPath": "github.com/stretchr/testify/mock",
"Rev": "73a8112e05603706e47153d5cbb86aac8d85adb0"
},
{
"ImportPath": "github.com/ungerik/go-gravatar",
"Rev": "6ab22628222adee0d853df36a2f0783f56ea55ec"
},
{
"ImportPath": "github.com/vrischmann/envconfig",
"Rev": "2985ccd302077381e9447e7b05af4d363fcae2c2"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "e0403b4e005737430c05a57aac078479844f919c"
},
{
"ImportPath": "gopkg.in/joeybloggs/go-validate-yourself.v4",
"Comment": "v4.0.4",
"Rev": "47a6634ff80051b92481b6c076854b85119fd267"
},
{
"ImportPath": "gopkg.in/yaml.v2",
"Rev": "49c95bdc21843256fb6c4e0d370a05f24a0bf213"
}
]
}

5
Godeps/Readme generated Normal file
View file

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
Godeps/_workspace/.gitignore generated vendored Normal file
View file

@ -0,0 +1,2 @@
/pkg
/bin

View file

@ -0,0 +1,563 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package agent implements a client to an ssh-agent daemon.
References:
[PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
*/
package agent
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"sync"
"golang.org/x/crypto/ssh"
)
// Agent represents the capabilities of an ssh-agent.
type Agent interface {
// List returns the identities known to the agent.
List() ([]*Key, error)
// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
// Insert adds a private key to the agent. If a certificate
// is given, that certificate is added as public key.
Add(s interface{}, cert *ssh.Certificate, comment string) error
// Remove removes all identities with the given public key.
Remove(key ssh.PublicKey) error
// RemoveAll removes all identities.
RemoveAll() error
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
Lock(passphrase []byte) error
// Unlock undoes the effect of Lock
Unlock(passphrase []byte) error
// Signers returns signers for all the known keys.
Signers() ([]ssh.Signer, error)
}
// See [PROTOCOL.agent], section 3.
const (
agentRequestV1Identities = 1
// 3.2 Requests from client to agent for protocol 2 key operations
agentAddIdentity = 17
agentRemoveIdentity = 18
agentRemoveAllIdentities = 19
agentAddIdConstrained = 25
// 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20
agentRemoveSmartcardKey = 21
agentLock = 22
agentUnlock = 23
agentAddSmartcardKeyConstrained = 26
// 3.7 Key constraint identifiers
agentConstrainLifetime = 1
agentConstrainConfirm = 2
)
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
// is a sanity check, not a limit in the spec.
const maxAgentResponseBytes = 16 << 20
// Agent messages:
// These structures mirror the wire format of the corresponding ssh agent
// messages found in [PROTOCOL.agent].
// 3.4 Generic replies from agent to client
const agentFailure = 5
type failureAgentMsg struct{}
const agentSuccess = 6
type successAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
const agentRequestIdentities = 11
type requestIdentitiesAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
const agentIdentitiesAnswer = 12
type identitiesAnswerAgentMsg struct {
NumKeys uint32 `sshtype:"12"`
Keys []byte `ssh:"rest"`
}
// See [PROTOCOL.agent], section 2.6.2.
const agentSignRequest = 13
type signRequestAgentMsg struct {
KeyBlob []byte `sshtype:"13"`
Data []byte
Flags uint32
}
// See [PROTOCOL.agent], section 2.6.2.
// 3.6 Replies from agent to client for protocol 2 key operations
const agentSignResponse = 14
type signResponseAgentMsg struct {
SigBlob []byte `sshtype:"14"`
}
type publicKey struct {
Format string
Rest []byte `ssh:"rest"`
}
// Key represents a protocol 2 public key as defined in
// [PROTOCOL.agent], section 2.5.2.
type Key struct {
Format string
Blob []byte
Comment string
}
func clientErr(err error) error {
return fmt.Errorf("agent: client error: %v", err)
}
// String returns the storage form of an agent key with the format, base64
// encoded serialized key, and the comment if it is not empty.
func (k *Key) String() string {
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob)
if k.Comment != "" {
s += " " + k.Comment
}
return s
}
// Type returns the public key type.
func (k *Key) Type() string {
return k.Format
}
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
func (k *Key) Marshal() []byte {
return k.Blob
}
// Verify satisfies the ssh.PublicKey interface, but is not
// implemented for agent keys.
func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
return errors.New("agent: agent key does not know how to verify")
}
type wireKey struct {
Format string
Rest []byte `ssh:"rest"`
}
func parseKey(in []byte) (out *Key, rest []byte, err error) {
var record struct {
Blob []byte
Comment string
Rest []byte `ssh:"rest"`
}
if err := ssh.Unmarshal(in, &record); err != nil {
return nil, nil, err
}
var wk wireKey
if err := ssh.Unmarshal(record.Blob, &wk); err != nil {
return nil, nil, err
}
return &Key{
Format: wk.Format,
Blob: record.Blob,
Comment: record.Comment,
}, record.Rest, nil
}
// client is a client for an ssh-agent process.
type client struct {
// conn is typically a *net.UnixConn
conn io.ReadWriter
// mu is used to prevent concurrent access to the agent
mu sync.Mutex
}
// NewClient returns an Agent that talks to an ssh-agent process over
// the given connection.
func NewClient(rw io.ReadWriter) Agent {
return &client{conn: rw}
}
// call sends an RPC to the agent. On success, the reply is
// unmarshaled into reply and replyType is set to the first byte of
// the reply, which contains the type of the message.
func (c *client) call(req []byte) (reply interface{}, err error) {
c.mu.Lock()
defer c.mu.Unlock()
msg := make([]byte, 4+len(req))
binary.BigEndian.PutUint32(msg, uint32(len(req)))
copy(msg[4:], req)
if _, err = c.conn.Write(msg); err != nil {
return nil, clientErr(err)
}
var respSizeBuf [4]byte
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
return nil, clientErr(err)
}
respSize := binary.BigEndian.Uint32(respSizeBuf[:])
if respSize > maxAgentResponseBytes {
return nil, clientErr(err)
}
buf := make([]byte, respSize)
if _, err = io.ReadFull(c.conn, buf); err != nil {
return nil, clientErr(err)
}
reply, err = unmarshal(buf)
if err != nil {
return nil, clientErr(err)
}
return reply, err
}
func (c *client) simpleCall(req []byte) error {
resp, err := c.call(req)
if err != nil {
return err
}
if _, ok := resp.(*successAgentMsg); ok {
return nil
}
return errors.New("agent: failure")
}
func (c *client) RemoveAll() error {
return c.simpleCall([]byte{agentRemoveAllIdentities})
}
func (c *client) Remove(key ssh.PublicKey) error {
req := ssh.Marshal(&agentRemoveIdentityMsg{
KeyBlob: key.Marshal(),
})
return c.simpleCall(req)
}
func (c *client) Lock(passphrase []byte) error {
req := ssh.Marshal(&agentLockMsg{
Passphrase: passphrase,
})
return c.simpleCall(req)
}
func (c *client) Unlock(passphrase []byte) error {
req := ssh.Marshal(&agentUnlockMsg{
Passphrase: passphrase,
})
return c.simpleCall(req)
}
// List returns the identities known to the agent.
func (c *client) List() ([]*Key, error) {
// see [PROTOCOL.agent] section 2.5.2.
req := []byte{agentRequestIdentities}
msg, err := c.call(req)
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *identitiesAnswerAgentMsg:
if msg.NumKeys > maxAgentResponseBytes/8 {
return nil, errors.New("agent: too many keys in agent reply")
}
keys := make([]*Key, msg.NumKeys)
data := msg.Keys
for i := uint32(0); i < msg.NumKeys; i++ {
var key *Key
var err error
if key, data, err = parseKey(data); err != nil {
return nil, err
}
keys[i] = key
}
return keys, nil
case *failureAgentMsg:
return nil, errors.New("agent: failed to list keys")
}
panic("unreachable")
}
// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
req := ssh.Marshal(signRequestAgentMsg{
KeyBlob: key.Marshal(),
Data: data,
})
msg, err := c.call(req)
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *signResponseAgentMsg:
var sig ssh.Signature
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
return nil, err
}
return &sig, nil
case *failureAgentMsg:
return nil, errors.New("agent: failed to sign challenge")
}
panic("unreachable")
}
// unmarshal parses an agent message in packet, returning the parsed
// form and the message type of packet.
func unmarshal(packet []byte) (interface{}, error) {
if len(packet) < 1 {
return nil, errors.New("agent: empty packet")
}
var msg interface{}
switch packet[0] {
case agentFailure:
return new(failureAgentMsg), nil
case agentSuccess:
return new(successAgentMsg), nil
case agentIdentitiesAnswer:
msg = new(identitiesAnswerAgentMsg)
case agentSignResponse:
msg = new(signResponseAgentMsg)
default:
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
}
if err := ssh.Unmarshal(packet, msg); err != nil {
return nil, err
}
return msg, nil
}
type rsaKeyMsg struct {
Type string `sshtype:"17"`
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
}
type dsaKeyMsg struct {
Type string `sshtype:"17"`
P *big.Int
Q *big.Int
G *big.Int
Y *big.Int
X *big.Int
Comments string
}
type ecdsaKeyMsg struct {
Type string `sshtype:"17"`
Curve string
KeyBytes []byte
D *big.Int
Comments string
}
// Insert adds a private key to the agent.
func (c *client) insertKey(s interface{}, comment string) error {
var req []byte
switch k := s.(type) {
case *rsa.PrivateKey:
if len(k.Primes) != 2 {
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
}
k.Precompute()
req = ssh.Marshal(rsaKeyMsg{
Type: ssh.KeyAlgoRSA,
N: k.N,
E: big.NewInt(int64(k.E)),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
})
case *dsa.PrivateKey:
req = ssh.Marshal(dsaKeyMsg{
Type: ssh.KeyAlgoDSA,
P: k.P,
Q: k.Q,
G: k.G,
Y: k.Y,
X: k.X,
Comments: comment,
})
case *ecdsa.PrivateKey:
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
req = ssh.Marshal(ecdsaKeyMsg{
Type: "ecdsa-sha2-" + nistID,
Curve: nistID,
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
D: k.D,
Comments: comment,
})
default:
return fmt.Errorf("agent: unsupported key type %T", s)
}
resp, err := c.call(req)
if err != nil {
return err
}
if _, ok := resp.(*successAgentMsg); ok {
return nil
}
return errors.New("agent: failure")
}
type rsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
}
type dsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
X *big.Int
Comments string
}
type ecdsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
D *big.Int
Comments string
}
// Insert adds a private key to the agent. If a certificate is given,
// that certificate is added instead as public key.
func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error {
if cert == nil {
return c.insertKey(s, comment)
} else {
return c.insertCert(s, cert, comment)
}
}
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error {
var req []byte
switch k := s.(type) {
case *rsa.PrivateKey:
if len(k.Primes) != 2 {
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
}
k.Precompute()
req = ssh.Marshal(rsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
})
case *dsa.PrivateKey:
req = ssh.Marshal(dsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
X: k.X,
Comments: comment,
})
case *ecdsa.PrivateKey:
req = ssh.Marshal(ecdsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Comments: comment,
})
default:
return fmt.Errorf("agent: unsupported key type %T", s)
}
signer, err := ssh.NewSignerFromKey(s)
if err != nil {
return err
}
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
return errors.New("agent: signer and cert have different public key")
}
resp, err := c.call(req)
if err != nil {
return err
}
if _, ok := resp.(*successAgentMsg); ok {
return nil
}
return errors.New("agent: failure")
}
// Signers provides a callback for client authentication.
func (c *client) Signers() ([]ssh.Signer, error) {
keys, err := c.List()
if err != nil {
return nil, err
}
var result []ssh.Signer
for _, k := range keys {
result = append(result, &agentKeyringSigner{c, k})
}
return result, nil
}
type agentKeyringSigner struct {
agent *client
pub ssh.PublicKey
}
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
return s.pub
}
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
// The agent has its own entropy source, so the rand argument is ignored.
return s.agent.Sign(s.pub, data)
}

View file

@ -0,0 +1,278 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"bytes"
"crypto/rand"
"errors"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"testing"
"golang.org/x/crypto/ssh"
)
// startAgent executes ssh-agent, and returns a Agent interface to it.
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
if testing.Short() {
// ssh-agent is not always available, and the key
// types supported vary by platform.
t.Skip("skipping test due to -short")
}
bin, err := exec.LookPath("ssh-agent")
if err != nil {
t.Skip("could not find ssh-agent")
}
cmd := exec.Command(bin, "-s")
out, err := cmd.Output()
if err != nil {
t.Fatalf("cmd.Output: %v", err)
}
/* Output looks like:
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
SSH_AGENT_PID=15542; export SSH_AGENT_PID;
echo Agent pid 15542;
*/
fields := bytes.Split(out, []byte(";"))
line := bytes.SplitN(fields[0], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AUTH_SOCK" {
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
}
socket = string(line[1])
line = bytes.SplitN(fields[2], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AGENT_PID" {
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
}
pidStr := line[1]
pid, err := strconv.Atoi(string(pidStr))
if err != nil {
t.Fatalf("Atoi(%q): %v", pidStr, err)
}
conn, err := net.Dial("unix", string(socket))
if err != nil {
t.Fatalf("net.Dial: %v", err)
}
ac := NewClient(conn)
return ac, socket, func() {
proc, _ := os.FindProcess(pid)
if proc != nil {
proc.Kill()
}
conn.Close()
os.RemoveAll(filepath.Dir(socket))
}
}
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) {
agent, _, cleanup := startAgent(t)
defer cleanup()
testAgentInterface(t, agent, key, cert)
}
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) {
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("NewSignerFromKey(%T): %v", key, err)
}
// The agent should start up empty.
if keys, err := agent.List(); err != nil {
t.Fatalf("RequestIdentities: %v", err)
} else if len(keys) > 0 {
t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
}
// Attempt to insert the key, with certificate if specified.
var pubKey ssh.PublicKey
if cert != nil {
err = agent.Add(key, cert, "comment")
pubKey = cert
} else {
err = agent.Add(key, nil, "comment")
pubKey = signer.PublicKey()
}
if err != nil {
t.Fatalf("insert(%T): %v", key, err)
}
// Did the key get inserted successfully?
if keys, err := agent.List(); err != nil {
t.Fatalf("List: %v", err)
} else if len(keys) != 1 {
t.Fatalf("got %v, want 1 key", keys)
} else if keys[0].Comment != "comment" {
t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
t.Fatalf("key mismatch")
}
// Can the agent make a valid signature?
data := []byte("hello")
sig, err := agent.Sign(pubKey, data)
if err != nil {
t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
}
if err := pubKey.Verify(data, sig); err != nil {
t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
}
}
func TestAgent(t *testing.T) {
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
testAgent(t, testPrivateKeys[keyType], nil)
}
}
func TestCert(t *testing.T) {
cert := &ssh.Certificate{
Key: testPublicKeys["rsa"],
ValidBefore: ssh.CertTimeInfinity,
CertType: ssh.UserCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
testAgent(t, testPrivateKeys["rsa"], cert)
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1, c2, nil
}
func TestAuth(t *testing.T) {
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer a.Close()
defer b.Close()
agent, _, cleanup := startAgent(t)
defer cleanup()
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
t.Errorf("Add: %v", err)
}
serverConf := ssh.ServerConfig{}
serverConf.AddHostKey(testSigners["rsa"])
serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}
return nil, errors.New("pubkey rejected")
}
go func() {
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
if err != nil {
t.Fatalf("Server: %v", err)
}
conn.Close()
}()
conf := ssh.ClientConfig{}
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
conn, _, _, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}
conn.Close()
}
func TestLockClient(t *testing.T) {
agent, _, cleanup := startAgent(t)
defer cleanup()
testLockAgent(agent, t)
}
func testLockAgent(agent Agent, t *testing.T) {
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil {
t.Errorf("Add: %v", err)
}
if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil {
t.Errorf("Add: %v", err)
}
if keys, err := agent.List(); err != nil {
t.Errorf("List: %v", err)
} else if len(keys) != 2 {
t.Errorf("Want 2 keys, got %v", keys)
}
passphrase := []byte("secret")
if err := agent.Lock(passphrase); err != nil {
t.Errorf("Lock: %v", err)
}
if keys, err := agent.List(); err != nil {
t.Errorf("List: %v", err)
} else if len(keys) != 0 {
t.Errorf("Want 0 keys, got %v", keys)
}
signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
t.Fatalf("Sign did not fail")
}
if err := agent.Remove(signer.PublicKey()); err == nil {
t.Fatalf("Remove did not fail")
}
if err := agent.RemoveAll(); err == nil {
t.Fatalf("RemoveAll did not fail")
}
if err := agent.Unlock(nil); err == nil {
t.Errorf("Unlock with wrong passphrase succeeded")
}
if err := agent.Unlock(passphrase); err != nil {
t.Errorf("Unlock: %v", err)
}
if err := agent.Remove(signer.PublicKey()); err != nil {
t.Fatalf("Remove: %v", err)
}
if keys, err := agent.List(); err != nil {
t.Errorf("List: %v", err)
} else if len(keys) != 1 {
t.Errorf("Want 1 keys, got %v", keys)
}
}

View file

@ -0,0 +1,103 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"errors"
"io"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
// RequestAgentForwarding sets up agent forwarding for the session.
// ForwardToAgent or ForwardToRemote should be called to route
// the authentication requests.
func RequestAgentForwarding(session *ssh.Session) error {
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
if err != nil {
return err
}
if !ok {
return errors.New("forwarding request denied")
}
return nil
}
// ForwardToAgent routes authentication requests to the given keyring.
func ForwardToAgent(client *ssh.Client, keyring Agent) error {
channels := client.HandleChannelOpen(channelType)
if channels == nil {
return errors.New("agent: already have handler for " + channelType)
}
go func() {
for ch := range channels {
channel, reqs, err := ch.Accept()
if err != nil {
continue
}
go ssh.DiscardRequests(reqs)
go func() {
ServeAgent(keyring, channel)
channel.Close()
}()
}
}()
return nil
}
const channelType = "auth-agent@openssh.com"
// ForwardToRemote routes authentication requests to the ssh-agent
// process serving on the given unix socket.
func ForwardToRemote(client *ssh.Client, addr string) error {
channels := client.HandleChannelOpen(channelType)
if channels == nil {
return errors.New("agent: already have handler for " + channelType)
}
conn, err := net.Dial("unix", addr)
if err != nil {
return err
}
conn.Close()
go func() {
for ch := range channels {
channel, reqs, err := ch.Accept()
if err != nil {
continue
}
go ssh.DiscardRequests(reqs)
go forwardUnixSocket(channel, addr)
}
}()
return nil
}
func forwardUnixSocket(channel ssh.Channel, addr string) {
conn, err := net.Dial("unix", addr)
if err != nil {
return
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
io.Copy(conn, channel)
conn.(*net.UnixConn).CloseWrite()
wg.Done()
}()
go func() {
io.Copy(channel, conn)
channel.CloseWrite()
wg.Done()
}()
wg.Wait()
conn.Close()
channel.Close()
}

View file

@ -0,0 +1,183 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"bytes"
"crypto/rand"
"crypto/subtle"
"errors"
"fmt"
"sync"
"golang.org/x/crypto/ssh"
)
type privKey struct {
signer ssh.Signer
comment string
}
type keyring struct {
mu sync.Mutex
keys []privKey
locked bool
passphrase []byte
}
var errLocked = errors.New("agent: locked")
// NewKeyring returns an Agent that holds keys in memory. It is safe
// for concurrent use by multiple goroutines.
func NewKeyring() Agent {
return &keyring{}
}
// RemoveAll removes all identities.
func (r *keyring) RemoveAll() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
r.keys = nil
return nil
}
// Remove removes all identities with the given public key.
func (r *keyring) Remove(key ssh.PublicKey) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
want := key.Marshal()
found := false
for i := 0; i < len(r.keys); {
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
found = true
r.keys[i] = r.keys[len(r.keys)-1]
r.keys = r.keys[len(r.keys)-1:]
continue
} else {
i++
}
}
if !found {
return errors.New("agent: key not found")
}
return nil
}
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
func (r *keyring) Lock(passphrase []byte) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
r.locked = true
r.passphrase = passphrase
return nil
}
// Unlock undoes the effect of Lock
func (r *keyring) Unlock(passphrase []byte) error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.locked {
return errors.New("agent: not locked")
}
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
return fmt.Errorf("agent: incorrect passphrase")
}
r.locked = false
r.passphrase = nil
return nil
}
// List returns the identities known to the agent.
func (r *keyring) List() ([]*Key, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
// section 2.7: locked agents return empty.
return nil, nil
}
var ids []*Key
for _, k := range r.keys {
pub := k.signer.PublicKey()
ids = append(ids, &Key{
Format: pub.Type(),
Blob: pub.Marshal(),
Comment: k.comment})
}
return ids, nil
}
// Insert adds a private key to the keyring. If a certificate
// is given, that certificate is added as public key.
func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
signer, err := ssh.NewSignerFromKey(priv)
if err != nil {
return err
}
if cert != nil {
signer, err = ssh.NewCertSigner(cert, signer)
if err != nil {
return err
}
}
r.keys = append(r.keys, privKey{signer, comment})
return nil
}
// Sign returns a signature for the data.
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return nil, errLocked
}
wanted := key.Marshal()
for _, k := range r.keys {
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
return k.signer.Sign(rand.Reader, data)
}
}
return nil, errors.New("not found")
}
// Signers returns signers for all the known keys.
func (r *keyring) Signers() ([]ssh.Signer, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return nil, errLocked
}
s := make([]ssh.Signer, len(r.keys))
for _, k := range r.keys {
s = append(s, k.signer)
}
return s, nil
}

View file

@ -0,0 +1,209 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"crypto/rsa"
"encoding/binary"
"fmt"
"io"
"log"
"math/big"
"golang.org/x/crypto/ssh"
)
// Server wraps an Agent and uses it to implement the agent side of
// the SSH-agent, wire protocol.
type server struct {
agent Agent
}
func (s *server) processRequestBytes(reqData []byte) []byte {
rep, err := s.processRequest(reqData)
if err != nil {
if err != errLocked {
// TODO(hanwen): provide better logging interface?
log.Printf("agent %d: %v", reqData[0], err)
}
return []byte{agentFailure}
}
if err == nil && rep == nil {
return []byte{agentSuccess}
}
return ssh.Marshal(rep)
}
func marshalKey(k *Key) []byte {
var record struct {
Blob []byte
Comment string
}
record.Blob = k.Marshal()
record.Comment = k.Comment
return ssh.Marshal(&record)
}
type agentV1IdentityMsg struct {
Numkeys uint32 `sshtype:"2"`
}
type agentRemoveIdentityMsg struct {
KeyBlob []byte `sshtype:"18"`
}
type agentLockMsg struct {
Passphrase []byte `sshtype:"22"`
}
type agentUnlockMsg struct {
Passphrase []byte `sshtype:"23"`
}
func (s *server) processRequest(data []byte) (interface{}, error) {
switch data[0] {
case agentRequestV1Identities:
return &agentV1IdentityMsg{0}, nil
case agentRemoveIdentity:
var req agentRemoveIdentityMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
var wk wireKey
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
return nil, err
}
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
case agentRemoveAllIdentities:
return nil, s.agent.RemoveAll()
case agentLock:
var req agentLockMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
return nil, s.agent.Lock(req.Passphrase)
case agentUnlock:
var req agentLockMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
return nil, s.agent.Unlock(req.Passphrase)
case agentSignRequest:
var req signRequestAgentMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
var wk wireKey
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
return nil, err
}
k := &Key{
Format: wk.Format,
Blob: req.KeyBlob,
}
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
if err != nil {
return nil, err
}
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
case agentRequestIdentities:
keys, err := s.agent.List()
if err != nil {
return nil, err
}
rep := identitiesAnswerAgentMsg{
NumKeys: uint32(len(keys)),
}
for _, k := range keys {
rep.Keys = append(rep.Keys, marshalKey(k)...)
}
return rep, nil
case agentAddIdentity:
return nil, s.insertIdentity(data)
}
return nil, fmt.Errorf("unknown opcode %d", data[0])
}
func (s *server) insertIdentity(req []byte) error {
var record struct {
Type string `sshtype:"17"`
Rest []byte `ssh:"rest"`
}
if err := ssh.Unmarshal(req, &record); err != nil {
return err
}
switch record.Type {
case ssh.KeyAlgoRSA:
var k rsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return err
}
priv := rsa.PrivateKey{
PublicKey: rsa.PublicKey{
E: int(k.E.Int64()),
N: k.N,
},
D: k.D,
Primes: []*big.Int{k.P, k.Q},
}
priv.Precompute()
return s.agent.Add(&priv, nil, k.Comments)
}
return fmt.Errorf("not implemented: %s", record.Type)
}
// ServeAgent serves the agent protocol on the given connection. It
// returns when an I/O error occurs.
func ServeAgent(agent Agent, c io.ReadWriter) error {
s := &server{agent}
var length [4]byte
for {
if _, err := io.ReadFull(c, length[:]); err != nil {
return err
}
l := binary.BigEndian.Uint32(length[:])
if l > maxAgentResponseBytes {
// We also cap requests.
return fmt.Errorf("agent: request too large: %d", l)
}
req := make([]byte, l)
if _, err := io.ReadFull(c, req); err != nil {
return err
}
repData := s.processRequestBytes(req)
if len(repData) > maxAgentResponseBytes {
return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
}
binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
if _, err := c.Write(length[:]); err != nil {
return err
}
if _, err := c.Write(repData); err != nil {
return err
}
}
}

View file

@ -0,0 +1,77 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"testing"
"golang.org/x/crypto/ssh"
)
func TestServer(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
client := NewClient(c1)
go ServeAgent(NewKeyring(), c2)
testAgentInterface(t, client, testPrivateKeys["rsa"], nil)
}
func TestLockServer(t *testing.T) {
testLockAgent(NewKeyring(), t)
}
func TestSetupForwardAgent(t *testing.T) {
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer a.Close()
defer b.Close()
_, socket, cleanup := startAgent(t)
defer cleanup()
serverConf := ssh.ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
incoming := make(chan *ssh.ServerConn, 1)
go func() {
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
if err != nil {
t.Fatalf("Server: %v", err)
}
incoming <- conn
}()
conf := ssh.ClientConfig{}
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}
client := ssh.NewClient(conn, chans, reqs)
if err := ForwardToRemote(client, socket); err != nil {
t.Fatalf("SetupForwardAgent: %v", err)
}
server := <-incoming
ch, reqs, err := server.OpenChannel(channelType, nil)
if err != nil {
t.Fatalf("OpenChannel(%q): %v", channelType, err)
}
go ssh.DiscardRequests(reqs)
agentClient := NewClient(ch)
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil)
conn.Close()
}

View file

@ -0,0 +1,64 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.
package agent
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"
)
var (
testPrivateKeys map[string]interface{}
testSigners map[string]ssh.Signer
testPublicKeys map[string]ssh.PublicKey
)
func init() {
var err error
n := len(testdata.PEMBytes)
testPrivateKeys = make(map[string]interface{}, n)
testSigners = make(map[string]ssh.Signer, n)
testPublicKeys = make(map[string]ssh.PublicKey, n)
for t, k := range testdata.PEMBytes {
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
if err != nil {
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
}
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
if err != nil {
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
}
testPublicKeys[t] = testSigners[t].PublicKey()
}
// Create a cert and sign it for use in tests.
testCert := &ssh.Certificate{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: testPublicKeys["ecdsa"],
SignatureKey: testPublicKeys["rsa"],
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
testCert.SignCert(rand.Reader, testSigners["rsa"])
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
if err != nil {
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
}
}

View file

@ -0,0 +1,122 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
"testing"
)
type server struct {
*ServerConn
chans <-chan NewChannel
}
func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
sconn, chans, reqs, err := NewServerConn(c, conf)
if err != nil {
return nil, err
}
go DiscardRequests(reqs)
return &server{sconn, chans}, nil
}
func (s *server) Accept() (NewChannel, error) {
n, ok := <-s.chans
if !ok {
return nil, io.EOF
}
return n, nil
}
func sshPipe() (Conn, *server, error) {
c1, c2, err := netPipe()
if err != nil {
return nil, nil, err
}
clientConf := ClientConfig{
User: "user",
}
serverConf := ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["ecdsa"])
done := make(chan *server, 1)
go func() {
server, err := newServer(c2, &serverConf)
if err != nil {
done <- nil
}
done <- server
}()
client, _, reqs, err := NewClientConn(c1, "", &clientConf)
if err != nil {
return nil, nil, err
}
server := <-done
if server == nil {
return nil, nil, errors.New("server handshake failed.")
}
go DiscardRequests(reqs)
return client, server, nil
}
func BenchmarkEndToEnd(b *testing.B) {
b.StopTimer()
client, server, err := sshPipe()
if err != nil {
b.Fatalf("sshPipe: %v", err)
}
defer client.Close()
defer server.Close()
size := (1 << 20)
input := make([]byte, size)
output := make([]byte, size)
b.SetBytes(int64(size))
done := make(chan int, 1)
go func() {
newCh, err := server.Accept()
if err != nil {
b.Fatalf("Client: %v", err)
}
ch, incoming, err := newCh.Accept()
go DiscardRequests(incoming)
for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(ch, output); err != nil {
b.Fatalf("ReadFull: %v", err)
}
}
ch.Close()
done <- 1
}()
ch, in, err := client.OpenChannel("speed", nil)
if err != nil {
b.Fatalf("OpenChannel: %v", err)
}
go DiscardRequests(in)
b.ResetTimer()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := ch.Write(input); err != nil {
b.Fatalf("WriteFull: %v", err)
}
}
ch.Close()
b.StopTimer()
<-done
}

View file

@ -0,0 +1,98 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
)
// buffer provides a linked list buffer for data exchange
// between producer and consumer. Theoretically the buffer is
// of unlimited capacity as it does no allocation of its own.
type buffer struct {
// protects concurrent access to head, tail and closed
*sync.Cond
head *element // the buffer that will be read first
tail *element // the buffer that will be read last
closed bool
}
// An element represents a single link in a linked list.
type element struct {
buf []byte
next *element
}
// newBuffer returns an empty buffer that is not closed.
func newBuffer() *buffer {
e := new(element)
b := &buffer{
Cond: newCond(),
head: e,
tail: e,
}
return b
}
// write makes buf available for Read to receive.
// buf must not be modified after the call to write.
func (b *buffer) write(buf []byte) {
b.Cond.L.Lock()
e := &element{buf: buf}
b.tail.next = e
b.tail = e
b.Cond.Signal()
b.Cond.L.Unlock()
}
// eof closes the buffer. Reads from the buffer once all
// the data has been consumed will receive os.EOF.
func (b *buffer) eof() error {
b.Cond.L.Lock()
b.closed = true
b.Cond.Signal()
b.Cond.L.Unlock()
return nil
}
// Read reads data from the internal buffer in buf. Reads will block
// if no data is available, or until the buffer is closed.
func (b *buffer) Read(buf []byte) (n int, err error) {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
for len(buf) > 0 {
// if there is data in b.head, copy it
if len(b.head.buf) > 0 {
r := copy(buf, b.head.buf)
buf, b.head.buf = buf[r:], b.head.buf[r:]
n += r
continue
}
// if there is a next buffer, make it the head
if len(b.head.buf) == 0 && b.head != b.tail {
b.head = b.head.next
continue
}
// if at least one byte has been copied, return
if n > 0 {
break
}
// if nothing was read, and there is nothing outstanding
// check to see if the buffer is closed.
if b.closed {
err = io.EOF
break
}
// out of buffers, wait for producer
b.Cond.Wait()
}
return
}

View file

@ -0,0 +1,87 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"testing"
)
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
func TestBufferReadwrite(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
r, _ := b.Read(make([]byte, 10))
if r != 10 {
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
r, _ = b.Read(make([]byte, 10))
if r != 5 {
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
}
b = newBuffer()
b.write(alphabet[:10])
r, _ = b.Read(make([]byte, 5))
if r != 5 {
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
}
b = newBuffer()
b.write(alphabet[:5])
b.write(alphabet[5:15])
r, _ = b.Read(make([]byte, 10))
r2, _ := b.Read(make([]byte, 10))
if r != 10 || r2 != 5 || 15 != r+r2 {
t.Fatal("Expected written == read == 15")
}
}
func TestBufferClose(t *testing.T) {
b := newBuffer()
b.write(alphabet[:10])
b.eof()
_, err := b.Read(make([]byte, 5))
if err != nil {
t.Fatal("expected read of 5 to not return EOF")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err := b.Read(make([]byte, 5))
r2, err2 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || err != nil || err2 != nil {
t.Fatal("expected reads of 5 and 5")
}
b = newBuffer()
b.write(alphabet[:10])
b.eof()
r, err = b.Read(make([]byte, 5))
r2, err2 = b.Read(make([]byte, 10))
r3, err3 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
t.Fatal("expected reads of 5 and 5 and 0, with EOF")
}
b = newBuffer()
b.write(make([]byte, 5))
b.write(make([]byte, 10))
b.eof()
r, err = b.Read(make([]byte, 9))
r2, err2 = b.Read(make([]byte, 3))
r3, err3 = b.Read(make([]byte, 3))
r4, err4 := b.Read(make([]byte, 10))
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
}
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
t.Fatal("Expected written == read == 15", r, r2, r3, r4)
}
}

View file

@ -0,0 +1,474 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"sort"
"time"
)
// These constants from [PROTOCOL.certkeys] represent the algorithm names
// for certificate types supported by this package.
const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
)
// Certificate types distinguish between host and user
// certificates. The values can be set in the CertType field of
// Certificate.
const (
UserCert = 1
HostCert = 2
)
// Signature represents a cryptographic signature.
type Signature struct {
Format string
Blob []byte
}
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
// a certificate does not expire.
const CertTimeInfinity = 1<<64 - 1
// An Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
type Certificate struct {
Nonce []byte
Key PublicKey
Serial uint64
CertType uint32
KeyId string
ValidPrincipals []string
ValidAfter uint64
ValidBefore uint64
Permissions
Reserved []byte
SignatureKey PublicKey
Signature *Signature
}
// genericCertData holds the key-independent part of the certificate data.
// Overall, certificates contain an nonce, public key fields and
// key-independent fields.
type genericCertData struct {
Serial uint64
CertType uint32
KeyId string
ValidPrincipals []byte
ValidAfter uint64
ValidBefore uint64
CriticalOptions []byte
Extensions []byte
Reserved []byte
SignatureKey []byte
Signature []byte
}
func marshalStringList(namelist []string) []byte {
var to []byte
for _, name := range namelist {
s := struct{ N string }{name}
to = append(to, Marshal(&s)...)
}
return to
}
func marshalTuples(tups map[string]string) []byte {
keys := make([]string, 0, len(tups))
for k := range tups {
keys = append(keys, k)
}
sort.Strings(keys)
var r []byte
for _, k := range keys {
s := struct{ K, V string }{k, tups[k]}
r = append(r, Marshal(&s)...)
}
return r
}
func parseTuples(in []byte) (map[string]string, error) {
tups := map[string]string{}
var lastKey string
var haveLastKey bool
for len(in) > 0 {
nameBytes, rest, ok := parseString(in)
if !ok {
return nil, errShortRead
}
data, rest, ok := parseString(rest)
if !ok {
return nil, errShortRead
}
name := string(nameBytes)
// according to [PROTOCOL.certkeys], the names must be in
// lexical order.
if haveLastKey && name <= lastKey {
return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
}
lastKey, haveLastKey = name, true
tups[name] = string(data)
in = rest
}
return tups, nil
}
func parseCert(in []byte, privAlgo string) (*Certificate, error) {
nonce, rest, ok := parseString(in)
if !ok {
return nil, errShortRead
}
key, rest, err := parsePubKey(rest, privAlgo)
if err != nil {
return nil, err
}
var g genericCertData
if err := Unmarshal(rest, &g); err != nil {
return nil, err
}
c := &Certificate{
Nonce: nonce,
Key: key,
Serial: g.Serial,
CertType: g.CertType,
KeyId: g.KeyId,
ValidAfter: g.ValidAfter,
ValidBefore: g.ValidBefore,
}
for principals := g.ValidPrincipals; len(principals) > 0; {
principal, rest, ok := parseString(principals)
if !ok {
return nil, errShortRead
}
c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
principals = rest
}
c.CriticalOptions, err = parseTuples(g.CriticalOptions)
if err != nil {
return nil, err
}
c.Extensions, err = parseTuples(g.Extensions)
if err != nil {
return nil, err
}
c.Reserved = g.Reserved
k, err := ParsePublicKey(g.SignatureKey)
if err != nil {
return nil, err
}
c.SignatureKey = k
c.Signature, rest, ok = parseSignatureBody(g.Signature)
if !ok || len(rest) > 0 {
return nil, errors.New("ssh: signature parse error")
}
return c, nil
}
type openSSHCertSigner struct {
pub *Certificate
signer Signer
}
// NewCertSigner returns a Signer that signs with the given Certificate, whose
// private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer.
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
return nil, errors.New("ssh: signer and cert have different public key")
}
return &openSSHCertSigner{cert, signer}, nil
}
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.signer.Sign(rand, data)
}
func (s *openSSHCertSigner) PublicKey() PublicKey {
return s.pub
}
const sourceAddressCriticalOption = "source-address"
// CertChecker does the work of verifying a certificate. Its methods
// can be plugged into ClientConfig.HostKeyCallback and
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
// minimally, the IsAuthority callback should be set.
type CertChecker struct {
// SupportedCriticalOptions lists the CriticalOptions that the
// server application layer understands. These are only used
// for user certificates.
SupportedCriticalOptions []string
// IsAuthority should return true if the key is recognized as
// an authority. This allows for certificates to be signed by other
// certificates.
IsAuthority func(auth PublicKey) bool
// Clock is used for verifying time stamps. If nil, time.Now
// is used.
Clock func() time.Time
// UserKeyFallback is called when CertChecker.Authenticate encounters a
// public key that is not a certificate. It must implement validation
// of user keys or else, if nil, all such keys are rejected.
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected.
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
// IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate
// is revoked and false otherwise. If nil, no certificates are
// considered to have been revoked.
IsRevoked func(cert *Certificate) bool
}
// CheckHostKey checks a host key certificate. This method can be
// plugged into ClientConfig.HostKeyCallback.
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
cert, ok := key.(*Certificate)
if !ok {
if c.HostKeyFallback != nil {
return c.HostKeyFallback(addr, remote, key)
}
return errors.New("ssh: non-certificate host key")
}
if cert.CertType != HostCert {
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
}
return c.CheckCert(addr, cert)
}
// Authenticate checks a user certificate. Authenticate can be used as
// a value for ServerConfig.PublicKeyCallback.
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
cert, ok := pubKey.(*Certificate)
if !ok {
if c.UserKeyFallback != nil {
return c.UserKeyFallback(conn, pubKey)
}
return nil, errors.New("ssh: normal key pairs not accepted")
}
if cert.CertType != UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
}
if err := c.CheckCert(conn.User(), cert); err != nil {
return nil, err
}
return &cert.Permissions, nil
}
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
// the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
}
for opt, _ := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by
// serverAuthenticate
if opt == sourceAddressCriticalOption {
continue
}
found := false
for _, supp := range c.SupportedCriticalOptions {
if supp == opt {
found = true
break
}
}
if !found {
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
}
}
if len(cert.ValidPrincipals) > 0 {
// By default, certs are valid for all users/hosts.
found := false
for _, p := range cert.ValidPrincipals {
if p == principal {
found = true
break
}
}
if !found {
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
}
}
if !c.IsAuthority(cert.SignatureKey) {
return fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
clock := c.Clock
if clock == nil {
clock = time.Now
}
unixNow := clock().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return fmt.Errorf("ssh: cert is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) {
return fmt.Errorf("ssh: cert has expired")
}
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
return fmt.Errorf("ssh: certificate signature does not verify")
}
return nil
}
// SignCert sets c.SignatureKey to the authority's public key and stores a
// Signature, by authority, in the certificate.
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
c.Nonce = make([]byte, 32)
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
return err
}
c.SignatureKey = authority.PublicKey()
sig, err := authority.Sign(rand, c.bytesForSigning())
if err != nil {
return err
}
c.Signature = sig
return nil
}
var certAlgoNames = map[string]string{
KeyAlgoRSA: CertAlgoRSAv01,
KeyAlgoDSA: CertAlgoDSAv01,
KeyAlgoECDSA256: CertAlgoECDSA256v01,
KeyAlgoECDSA384: CertAlgoECDSA384v01,
KeyAlgoECDSA521: CertAlgoECDSA521v01,
}
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
// Panics if a non-certificate algorithm is passed.
func certToPrivAlgo(algo string) string {
for privAlgo, pubAlgo := range certAlgoNames {
if pubAlgo == algo {
return privAlgo
}
}
panic("unknown cert algorithm")
}
func (cert *Certificate) bytesForSigning() []byte {
c2 := *cert
c2.Signature = nil
out := c2.Marshal()
// Drop trailing signature length.
return out[:len(out)-4]
}
// Marshal serializes c into OpenSSH's wire format. It is part of the
// PublicKey interface.
func (c *Certificate) Marshal() []byte {
generic := genericCertData{
Serial: c.Serial,
CertType: c.CertType,
KeyId: c.KeyId,
ValidPrincipals: marshalStringList(c.ValidPrincipals),
ValidAfter: uint64(c.ValidAfter),
ValidBefore: uint64(c.ValidBefore),
CriticalOptions: marshalTuples(c.CriticalOptions),
Extensions: marshalTuples(c.Extensions),
Reserved: c.Reserved,
SignatureKey: c.SignatureKey.Marshal(),
}
if c.Signature != nil {
generic.Signature = Marshal(c.Signature)
}
genericBytes := Marshal(&generic)
keyBytes := c.Key.Marshal()
_, keyBytes, _ = parseString(keyBytes)
prefix := Marshal(&struct {
Name string
Nonce []byte
Key []byte `ssh:"rest"`
}{c.Type(), c.Nonce, keyBytes})
result := make([]byte, 0, len(prefix)+len(genericBytes))
result = append(result, prefix...)
result = append(result, genericBytes...)
return result
}
// Type returns the key name. It is part of the PublicKey interface.
func (c *Certificate) Type() string {
algo, ok := certAlgoNames[c.Key.Type()]
if !ok {
panic("unknown cert key type")
}
return algo
}
// Verify verifies a signature against the certificate's public
// key. It is part of the PublicKey interface.
func (c *Certificate) Verify(data []byte, sig *Signature) error {
return c.Key.Verify(data, sig)
}
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
format, in, ok := parseString(in)
if !ok {
return
}
out = &Signature{
Format: string(format),
}
if out.Blob, in, ok = parseString(in); !ok {
return
}
return out, in, ok
}
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
sigBytes, rest, ok := parseString(in)
if !ok {
return
}
out, trailing, ok := parseSignatureBody(sigBytes)
if !ok || len(trailing) > 0 {
return nil, nil, false
}
return
}

View file

@ -0,0 +1,156 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"testing"
"time"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
// % ssh-keygen -s ca-key -I test user-key
var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
if _, ok := key.(*Certificate); !ok {
t.Fatalf("got %#v, want *Certificate", key)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
func TestValidateCert(t *testing.T) {
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
validCert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
checker.IsAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
if err := checker.CheckCert("user", validCert); err != nil {
t.Errorf("Unable to validate certificate: %v", err)
}
invalidCert := &Certificate{
Key: testPublicKeys["rsa"],
SignatureKey: testPublicKeys["ecdsa"],
ValidBefore: CertTimeInfinity,
Signature: &Signature{},
}
if err := checker.CheckCert("user", invalidCert); err == nil {
t.Error("Invalid cert signature passed validation")
}
}
func TestValidateCertTime(t *testing.T) {
cert := Certificate{
ValidPrincipals: []string{"user"},
Key: testPublicKeys["rsa"],
ValidAfter: 50,
ValidBefore: 100,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
for ts, ok := range map[int64]bool{
25: false,
50: true,
99: true,
100: false,
125: false,
} {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
checker.IsAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
if v := checker.CheckCert("user", &cert); (v == nil) != ok {
t.Errorf("Authenticate(%d): %v", ts, v)
}
}
}
// TODO(hanwen): tests for
//
// host keys:
// * fallbacks
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
IsAuthority: func(p PublicKey) bool {
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Errorf("NewCertSigner: %v", err)
}
for _, name := range []string{"hostname", "otherhost"} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go func() {
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(certSigner)
_, _, _, err := NewServerConn(c1, &conf)
if err != nil {
t.Fatalf("NewServerConn: %v", err)
}
}()
config := &ClientConfig{
User: "user",
HostKeyCallback: checker.CheckHostKey,
}
_, _, _, err = NewClientConn(c2, name, config)
succeed := name == "hostname"
if (err == nil) != succeed {
t.Fatalf("NewClientConn(%q): %v", name, err)
}
}
}

View file

@ -0,0 +1,631 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"sync"
)
const (
minPacketLength = 9
// channelMaxPacket contains the maximum number of bytes that will be
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
// the minimum.
channelMaxPacket = 1 << 15
// We follow OpenSSH here.
channelWindowSize = 64 * channelMaxPacket
)
// NewChannel represents an incoming request to a channel. It must either be
// accepted for use by calling Accept, or rejected by calling Reject.
type NewChannel interface {
// Accept accepts the channel creation request. It returns the Channel
// and a Go channel containing SSH requests. The Go channel must be
// serviced otherwise the Channel will hang.
Accept() (Channel, <-chan *Request, error)
// Reject rejects the channel creation request. After calling
// this, no other methods on the Channel may be called.
Reject(reason RejectionReason, message string) error
// ChannelType returns the type of the channel, as supplied by the
// client.
ChannelType() string
// ExtraData returns the arbitrary payload for this channel, as supplied
// by the client. This data is specific to the channel type.
ExtraData() []byte
}
// A Channel is an ordered, reliable, flow-controlled, duplex stream
// that is multiplexed over an SSH connection.
type Channel interface {
// Read reads up to len(data) bytes from the channel.
Read(data []byte) (int, error)
// Write writes len(data) bytes to the channel.
Write(data []byte) (int, error)
// Close signals end of channel use. No data may be sent after this
// call.
Close() error
// CloseWrite signals the end of sending in-band
// data. Requests may still be sent, and the other side may
// still send data
CloseWrite() error
// SendRequest sends a channel request. If wantReply is true,
// it will wait for a reply and return the result as a
// boolean, otherwise the return value will be false. Channel
// requests are out-of-band messages so they may be sent even
// if the data stream is closed or blocked by flow control.
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
// Stderr returns an io.ReadWriter that writes to this channel
// with the extended data type set to stderr. Stderr may
// safely be read and written from a different goroutine than
// Read and Write respectively.
Stderr() io.ReadWriter
}
// Request is a request sent outside of the normal stream of
// data. Requests can either be specific to an SSH channel, or they
// can be global.
type Request struct {
Type string
WantReply bool
Payload []byte
ch *channel
mux *mux
}
// Reply sends a response to a request. It must be called for all requests
// where WantReply is true and is a no-op otherwise. The payload argument is
// ignored for replies to channel-specific requests.
func (r *Request) Reply(ok bool, payload []byte) error {
if !r.WantReply {
return nil
}
if r.ch == nil {
return r.mux.ackRequest(ok, payload)
}
return r.ch.ackRequest(ok)
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason uint32
const (
Prohibited RejectionReason = iota + 1
ConnectionFailed
UnknownChannelType
ResourceShortage
)
// String converts the rejection reason to human readable form.
func (r RejectionReason) String() string {
switch r {
case Prohibited:
return "administratively prohibited"
case ConnectionFailed:
return "connect failed"
case UnknownChannelType:
return "unknown channel type"
case ResourceShortage:
return "resource shortage"
}
return fmt.Sprintf("unknown reason %d", int(r))
}
func min(a uint32, b int) uint32 {
if a < uint32(b) {
return a
}
return uint32(b)
}
type channelDirection uint8
const (
channelInbound channelDirection = iota
channelOutbound
)
// channel is an implementation of the Channel interface that works
// with the mux class.
type channel struct {
// R/O after creation
chanType string
extraData []byte
localId, remoteId uint32
// maxIncomingPayload and maxRemotePayload are the maximum
// payload sizes of normal and extended data packets for
// receiving and sending, respectively. The wire packet will
// be 9 or 13 bytes larger (excluding encryption overhead).
maxIncomingPayload uint32
maxRemotePayload uint32
mux *mux
// decided is set to true if an accept or reject message has been sent
// (for outbound channels) or received (for inbound channels).
decided bool
// direction contains either channelOutbound, for channels created
// locally, or channelInbound, for channels created by the peer.
direction channelDirection
// Pending internal channel messages.
msg chan interface{}
// Since requests have no ID, there can be only one request
// with WantReply=true outstanding. This lock is held by a
// goroutine that has such an outgoing request pending.
sentRequestMu sync.Mutex
incomingRequests chan *Request
sentEOF bool
// thread-safe data
remoteWin window
pending *buffer
extPending *buffer
// windowMu protects myWindow, the flow-control window.
windowMu sync.Mutex
myWindow uint32
// writeMu serializes calls to mux.conn.writePacket() and
// protects sentClose and packetPool. This mutex must be
// different from windowMu, as writePacket can block if there
// is a key exchange pending.
writeMu sync.Mutex
sentClose bool
// packetPool has a buffer for each extended channel ID to
// save allocations during writes.
packetPool map[uint32][]byte
}
// writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error {
c.writeMu.Lock()
if c.sentClose {
c.writeMu.Unlock()
return io.EOF
}
c.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet)
c.writeMu.Unlock()
return err
}
func (c *channel) sendMessage(msg interface{}) error {
if debugMux {
log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
}
p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId)
return c.writePacket(p)
}
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF {
return 0, io.EOF
}
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
opCode := byte(msgChannelData)
headerLength := uint32(9)
if extendedCode > 0 {
headerLength += 4
opCode = msgChannelExtendedData
}
c.writeMu.Lock()
packet := c.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be
// flagged as errors by the race detector.
c.writeMu.Unlock()
for len(data) > 0 {
space := min(c.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil {
return n, err
}
if want := headerLength + space; uint32(cap(packet)) < want {
packet = make([]byte, want)
} else {
packet = packet[:want]
}
todo := data[:space]
packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId)
if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
}
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil {
return n, err
}
n += len(todo)
data = data[len(todo):]
}
c.writeMu.Lock()
c.packetPool[extendedCode] = packet
c.writeMu.Unlock()
return n, err
}
func (c *channel) handleData(packet []byte) error {
headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData {
headerLen = 13
}
if len(packet) < headerLen {
// malformed data packet
return parseError(packet[0])
}
var extended uint32
if isExtendedData {
extended = binary.BigEndian.Uint32(packet[5:])
}
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
if length == 0 {
return nil
}
if length > c.maxIncomingPayload {
// TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size")
}
data := packet[headerLen:]
if length != uint32(len(data)) {
return errors.New("ssh: wrong packet length")
}
c.windowMu.Lock()
if c.myWindow < length {
c.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much")
}
c.myWindow -= length
c.windowMu.Unlock()
if extended == 1 {
c.extPending.write(data)
} else if extended > 0 {
// discard other extended data.
} else {
c.pending.write(data)
}
return nil
}
func (c *channel) adjustWindow(n uint32) error {
c.windowMu.Lock()
// Since myWindow is managed on our side, and can never exceed
// the initial window setting, we don't worry about overflow.
c.myWindow += uint32(n)
c.windowMu.Unlock()
return c.sendMessage(windowAdjustMsg{
AdditionalBytes: uint32(n),
})
}
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
switch extended {
case 1:
n, err = c.extPending.Read(data)
case 0:
n, err = c.pending.Read(data)
default:
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
}
if n > 0 {
err = c.adjustWindow(uint32(n))
// sendWindowAdjust can return io.EOF if the remote
// peer has closed the connection, however we want to
// defer forwarding io.EOF to the caller of Read until
// the buffer has been drained.
if n > 0 && err == io.EOF {
err = nil
}
}
return n, err
}
func (c *channel) close() {
c.pending.eof()
c.extPending.eof()
close(c.msg)
close(c.incomingRequests)
c.writeMu.Lock()
// This is not necesary for a normal channel teardown, but if
// there was another error, it is.
c.sentClose = true
c.writeMu.Unlock()
// Unblock writers.
c.remoteWin.close()
}
// responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the
// given channel.
func (c *channel) responseMessageReceived() error {
if c.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel")
}
if c.decided {
return errors.New("ssh: duplicate response received for channel")
}
c.decided = true
return nil
}
func (c *channel) handlePacket(packet []byte) error {
switch packet[0] {
case msgChannelData, msgChannelExtendedData:
return c.handleData(packet)
case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
c.mux.chanList.remove(c.localId)
c.close()
return nil
case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
c.extPending.eof()
c.pending.eof()
return nil
}
decoded, err := decode(packet)
if err != nil {
return err
}
switch msg := decoded.(type) {
case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil {
return err
}
c.mux.chanList.remove(msg.PeersId)
c.msg <- msg
case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil {
return err
}
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
}
c.remoteId = msg.MyId
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow)
c.msg <- msg
case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
}
case *channelRequestMsg:
req := Request{
Type: msg.Request,
WantReply: msg.WantReply,
Payload: msg.RequestSpecificData,
ch: c,
}
c.incomingRequests <- &req
default:
c.msg <- msg
}
return nil
}
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
ch := &channel{
remoteWin: window{Cond: newCond()},
myWindow: channelWindowSize,
pending: newBuffer(),
extPending: newBuffer(),
direction: direction,
incomingRequests: make(chan *Request, 16),
msg: make(chan interface{}, 16),
chanType: chanType,
extraData: extraData,
mux: m,
packetPool: make(map[uint32][]byte),
}
ch.localId = m.chanList.add(ch)
return ch
}
var errUndecided = errors.New("ssh: must Accept or Reject channel")
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
type extChannel struct {
code uint32
ch *channel
}
func (e *extChannel) Write(data []byte) (n int, err error) {
return e.ch.WriteExtended(data, e.code)
}
func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code)
}
func (c *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided {
return nil, nil, errDecidedAlready
}
c.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{
PeersId: c.remoteId,
MyId: c.localId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxIncomingPayload,
}
c.decided = true
if err := c.sendMessage(confirm); err != nil {
return nil, nil, err
}
return c, c.incomingRequests, nil
}
func (ch *channel) Reject(reason RejectionReason, message string) error {
if ch.decided {
return errDecidedAlready
}
reject := channelOpenFailureMsg{
PeersId: ch.remoteId,
Reason: reason,
Message: message,
Language: "en",
}
ch.decided = true
return ch.sendMessage(reject)
}
func (ch *channel) Read(data []byte) (int, error) {
if !ch.decided {
return 0, errUndecided
}
return ch.ReadExtended(data, 0)
}
func (ch *channel) Write(data []byte) (int, error) {
if !ch.decided {
return 0, errUndecided
}
return ch.WriteExtended(data, 0)
}
func (ch *channel) CloseWrite() error {
if !ch.decided {
return errUndecided
}
ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId})
}
func (ch *channel) Close() error {
if !ch.decided {
return errUndecided
}
return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId})
}
// Extended returns an io.ReadWriter that sends and receives data on the given,
// SSH extended stream. Such streams are used, for example, for stderr.
func (ch *channel) Extended(code uint32) io.ReadWriter {
if !ch.decided {
return nil
}
return &extChannel{code, ch}
}
func (ch *channel) Stderr() io.ReadWriter {
return ch.Extended(1)
}
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
if !ch.decided {
return false, errUndecided
}
if wantReply {
ch.sentRequestMu.Lock()
defer ch.sentRequestMu.Unlock()
}
msg := channelRequestMsg{
PeersId: ch.remoteId,
Request: name,
WantReply: wantReply,
RequestSpecificData: payload,
}
if err := ch.sendMessage(msg); err != nil {
return false, err
}
if wantReply {
m, ok := (<-ch.msg)
if !ok {
return false, io.EOF
}
switch m.(type) {
case *channelRequestFailureMsg:
return false, nil
case *channelRequestSuccessMsg:
return true, nil
default:
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
}
}
return false, nil
}
// ackRequest either sends an ack or nack to the channel request.
func (ch *channel) ackRequest(ok bool) error {
if !ch.decided {
return errUndecided
}
var msg interface{}
if !ok {
msg = channelRequestFailureMsg{
PeersId: ch.remoteId,
}
} else {
msg = channelRequestSuccessMsg{
PeersId: ch.remoteId,
}
}
return ch.sendMessage(msg)
}
func (ch *channel) ChannelType() string {
return ch.chanType
}
func (ch *channel) ExtraData() []byte {
return ch.extraData
}

View file

@ -0,0 +1,344 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
)
const (
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
// indicates implementations SHOULD be able to handle larger packet sizes, but then
// waffles on about reasonable limits.
//
// OpenSSH caps their maxPacket at 256kB so we choose to do
// the same. maxPacket is also used to ensure that uint32
// length fields do not overflow, so it should remain well
// below 4G.
maxPacket = 256 * 1024
)
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}
func (c noneCipher) XORKeyStream(dst, src []byte) {
copy(dst, src)
}
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(c, iv), nil
}
func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type streamCipherMode struct {
keySize int
ivSize int
skip int
createFunc func(key, iv []byte) (cipher.Stream, error)
}
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
var streamDump []byte
if c.skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*streamCipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, 0, newRC4},
// AES-GCM is not a stream cipher, so it is constructed with a
// special case. If we add any more non-stream ciphers, we
// should invest a cleaner way to do this.
gcmCipherID: {16, 12, 0, nil},
}
// prefixLen is the length of the packet prefix that contains the packet length
// and number of padding bytes.
const prefixLen = 5
// streamPacketCipher is a packetCipher using a stream cipher.
type streamPacketCipher struct {
mac hash.Hash
cipher cipher.Stream
// The following members are to avoid per-packet allocations.
prefix [prefixLen]byte
seqNumBytes [4]byte
padding [2 * packetSizeMultiple]byte
packetData []byte
macResult []byte
}
// readPacket reads and decrypt a single packet from the reader argument.
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
return nil, err
}
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
length := binary.BigEndian.Uint32(s.prefix[0:4])
paddingLength := uint32(s.prefix[4])
var macSize uint32
if s.mac != nil {
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
s.mac.Write(s.prefix[:])
macSize = uint32(s.mac.Size())
}
if length <= paddingLength+1 {
return nil, errors.New("ssh: invalid packet length, packet too small")
}
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
// the maxPacket check above ensures that length-1+macSize
// does not overflow.
if uint32(cap(s.packetData)) < length-1+macSize {
s.packetData = make([]byte, length-1+macSize)
} else {
s.packetData = s.packetData[:length-1+macSize]
}
if _, err := io.ReadFull(r, s.packetData); err != nil {
return nil, err
}
mac := s.packetData[length-1:]
data := s.packetData[:length-1]
s.cipher.XORKeyStream(data, data)
if s.mac != nil {
s.mac.Write(data)
s.macResult = s.mac.Sum(s.macResult[:0])
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
return nil, errors.New("ssh: MAC failure")
}
}
return s.packetData[:length-paddingLength-1], nil
}
// writePacket encrypts and sends a packet of data to the writer argument
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
if len(packet) > maxPacket {
return errors.New("ssh: packet too large")
}
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple
if paddingLength < 4 {
paddingLength += packetSizeMultiple
}
length := len(packet) + 1 + paddingLength
binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
s.prefix[4] = byte(paddingLength)
padding := s.padding[:paddingLength]
if _, err := io.ReadFull(rand, padding); err != nil {
return err
}
if s.mac != nil {
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
s.mac.Write(s.prefix[:])
s.mac.Write(packet)
s.mac.Write(padding)
}
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(s.prefix[:]); err != nil {
return err
}
if _, err := w.Write(packet); err != nil {
return err
}
if _, err := w.Write(padding); err != nil {
return err
}
if s.mac != nil {
s.macResult = s.mac.Sum(s.macResult[:0])
if _, err := w.Write(s.macResult); err != nil {
return err
}
}
return nil
}
type gcmCipher struct {
aead cipher.AEAD
prefix [4]byte
iv []byte
buf []byte
}
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return &gcmCipher{
aead: aead,
iv: iv,
}, nil
}
const gcmTagSize = 16
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
// Pad out to multiple of 16 bytes. This is different from the
// stream cipher because that encrypts the length too.
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
if padding < 4 {
padding += packetSizeMultiple
}
length := uint32(len(packet) + int(padding) + 1)
binary.BigEndian.PutUint32(c.prefix[:], length)
if _, err := w.Write(c.prefix[:]); err != nil {
return err
}
if cap(c.buf) < int(length) {
c.buf = make([]byte, length)
} else {
c.buf = c.buf[:length]
}
c.buf[0] = padding
copy(c.buf[1:], packet)
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
return err
}
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
c.incIV()
return nil
}
func (c *gcmCipher) incIV() {
for i := 4 + 7; i >= 4; i-- {
c.iv[i]++
if c.iv[i] != 0 {
break
}
}
}
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
return nil, err
}
length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.")
}
if cap(c.buf) < int(length+gcmTagSize) {
c.buf = make([]byte, length+gcmTagSize)
} else {
c.buf = c.buf[:length+gcmTagSize]
}
if _, err := io.ReadFull(r, c.buf); err != nil {
return nil, err
}
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
if err != nil {
return nil, err
}
c.incIV()
padding := plain[0]
if padding < 4 || padding >= 20 {
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
if int(padding+1) >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding)
}
plain = plain[1 : length-uint32(padding)]
return plain, nil
}

View file

@ -0,0 +1,59 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/rand"
"testing"
)
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range supportedCiphers {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
func TestPacketCiphers(t *testing.T) {
for cipher := range cipherModes {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: cipher,
MAC: "hmac-sha1",
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
continue
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
continue
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q): %v", cipher, err)
continue
}
packet, err := server.readPacket(0, buf)
if err != nil {
t.Errorf("readPacket(%q): %v", cipher, err)
continue
}
if string(packet) != want {
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
}
}
}

View file

@ -0,0 +1,202 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"fmt"
"net"
"sync"
)
// Client implements a traditional SSH client that supports shells,
// subprocesses, port forwarding and tunneled dialing.
type Client struct {
Conn
forwards forwardList // forwarded tcpip connections from the remote side
mu sync.Mutex
channelHandlers map[string]chan NewChannel
}
// HandleChannelOpen returns a channel on which NewChannel requests
// for the given type are sent. If the type already is being handled,
// nil is returned. The channel is closed when the connection is closed.
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
c.mu.Lock()
defer c.mu.Unlock()
if c.channelHandlers == nil {
// The SSH channel has been closed.
c := make(chan NewChannel)
close(c)
return c
}
ch := c.channelHandlers[channelType]
if ch != nil {
return nil
}
ch = make(chan NewChannel, 16)
c.channelHandlers[channelType] = ch
return ch
}
// NewClient creates a Client on top of the given connection.
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn := &Client{
Conn: c,
channelHandlers: make(map[string]chan NewChannel, 1),
}
go conn.handleGlobalRequests(reqs)
go conn.handleChannelOpens(chans)
go func() {
conn.Wait()
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
return conn
}
// NewClientConn establishes an authenticated SSH connection using c
// as the underlying transport. The Request and NewChannel channels
// must be serviced or the connection will hang.
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
conn := &connection{
sshConn: sshConn{conn: c},
}
if err := conn.clientHandshake(addr, &fullConf); err != nil {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
}
conn.mux = newMux(conn.transport)
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
}
// clientHandshake performs the client side key exchange. See RFC 4253 Section
// 7.
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
c.clientVersion = []byte(packageVersion)
if config.ClientVersion != "" {
c.clientVersion = []byte(config.ClientVersion)
}
var err error
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
if err != nil {
return err
}
c.transport = newClientTransport(
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
if err := c.transport.requestKeyChange(); err != nil {
return err
}
if packet, err := c.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
return c.clientAuthenticate(config)
}
// verifyHostKeySignature verifies the host key obtained in the key
// exchange.
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error {
sig, rest, ok := parseSignatureBody(result.Signature)
if len(rest) > 0 || !ok {
return errors.New("ssh: signature parse error")
}
return hostKey.Verify(result.H, sig)
}
// NewSession opens a new Session for this client. (A session is a remote
// execution of a program.)
func (c *Client) NewSession() (*Session, error) {
ch, in, err := c.OpenChannel("session", nil)
if err != nil {
return nil, err
}
return newSession(ch, in)
}
func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
for r := range incoming {
// This handles keepalive messages and matches
// the behaviour of OpenSSH.
r.Reply(false, nil)
}
}
// handleChannelOpens channel open messages from the remote side.
func (c *Client) handleChannelOpens(in <-chan NewChannel) {
for ch := range in {
c.mu.Lock()
handler := c.channelHandlers[ch.ChannelType()]
c.mu.Unlock()
if handler != nil {
handler <- ch
} else {
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
}
}
c.mu.Lock()
for _, ch := range c.channelHandlers {
close(ch)
}
c.channelHandlers = nil
c.mu.Unlock()
}
// Dial starts a client connection to the given SSH server. It is a
// convenience function that connects to the given network address,
// initiates the SSH handshake, and then sets up a Client. For access
// to incoming channels and requests, use net.Dial with NewClientConn
// instead.
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return NewClient(c, chans, reqs), nil
}
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
type ClientConfig struct {
// Config contains configuration that is shared between clients and
// servers.
Config
// User contains the username to authenticate as.
User string
// Auth contains possible authentication methods to use with the
// server. Only the first instance of a particular RFC 4252 method will
// be used during authentication.
Auth []AuthMethod
// HostKeyCallback, if not nil, is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyCallback
// implies that all host keys are accepted.
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
ClientVersion string
}

View file

@ -0,0 +1,441 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
)
// clientAuthenticate authenticates with the remote server. See RFC 4252.
func (c *connection) clientAuthenticate(config *ClientConfig) error {
// initiate user auth session
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.transport.readPacket()
if err != nil {
return err
}
var serviceAccept serviceAcceptMsg
if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
}
// during the authentication phase the client first attempts the "none" method
// then any untried methods suggested by the server.
tried := make(map[string]bool)
var lastMethods []string
for auth := AuthMethod(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
if err != nil {
return err
}
if ok {
// success
return nil
}
tried[auth.method()] = true
if methods == nil {
methods = lastMethods
}
lastMethods = methods
auth = nil
findNext:
for _, a := range config.Auth {
candidateMethod := a.method()
if tried[candidateMethod] {
continue
}
for _, meth := range methods {
if meth == candidateMethod {
auth = a
break findNext
}
}
}
}
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried))
}
func keys(m map[string]bool) []string {
s := make([]string, 0, len(m))
for key := range m {
s = append(s, key)
}
return s
}
// An AuthMethod represents an instance of an RFC 4252 authentication method.
type AuthMethod interface {
// auth authenticates user over transport t.
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
// method names is returned. If the slice is nil, it will be ignored
// and the previous set of possible methods will be reused.
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name.
method() string
}
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
if err := c.writePacket(Marshal(&userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
})); err != nil {
return false, nil, err
}
return handleAuthResponse(c)
}
func (n *noneAuth) method() string {
return "none"
}
// passwordCallback is an AuthMethod that fetches the password through
// a function call, e.g. by prompting the user.
type passwordCallback func() (password string, err error)
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
User string `sshtype:"50"`
Service string
Method string
Reply bool
Password string
}
pw, err := cb()
// REVIEW NOTE: is there a need to support skipping a password attempt?
// The program may only find out that the user doesn't have a password
// when prompting.
if err != nil {
return false, nil, err
}
if err := c.writePacket(Marshal(&passwordAuthMsg{
User: user,
Service: serviceSSH,
Method: cb.method(),
Reply: false,
Password: pw,
})); err != nil {
return false, nil, err
}
return handleAuthResponse(c)
}
func (cb passwordCallback) method() string {
return "password"
}
// Password returns an AuthMethod using the given password.
func Password(secret string) AuthMethod {
return passwordCallback(func() (string, error) { return secret, nil })
}
// PasswordCallback returns an AuthMethod that uses a callback for
// fetching a password.
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
return passwordCallback(prompt)
}
type publickeyAuthMsg struct {
User string `sshtype:"50"`
Service string
Method string
// HasSig indicates to the receiver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
PubKey []byte
// Sig is tagged with "rest" so Marshal will exclude it during
// validateKey
Sig []byte `ssh:"rest"`
}
// publicKeyCallback is an AuthMethod that uses a set of key
// pairs for authentication.
type publicKeyCallback func() ([]Signer, error)
func (cb publicKeyCallback) method() string {
return "publickey"
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
signers, err := cb()
if err != nil {
return false, nil, err
}
var validKeys []Signer
for _, signer := range signers {
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
}
// methods that may continue if this auth is not successful.
var methods []string
for _, signer := range validKeys {
pub := signer.PublicKey()
pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: cb.method(),
}, []byte(pub.Type()), pubKey))
if err != nil {
return false, nil, err
}
// manually wrap the serialized signature in a string
s := Marshal(sign)
sig := make([]byte, stringLength(len(s)))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: cb.method(),
HasSig: true,
Algoname: pub.Type(),
PubKey: pubKey,
Sig: sig,
}
p := Marshal(&msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
}
var success bool
success, methods, err = handleAuthResponse(c)
if err != nil {
return false, nil, err
}
if success {
return success, methods, err
}
}
return false, methods, nil
}
// validateKey validates the key provided is acceptable to the server.
func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubKey := key.Marshal()
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: "publickey",
HasSig: false,
Algoname: key.Type(),
PubKey: pubKey,
}
if err := c.writePacket(Marshal(&msg)); err != nil {
return false, err
}
return confirmKeyAck(key, c)
}
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
pubKey := key.Marshal()
algoname := key.Type()
for {
packet, err := c.readPacket()
if err != nil {
return false, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, err
}
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
return false, nil
}
return true, nil
case msgUserAuthFailure:
return false, nil
default:
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
}
// PublicKeys returns an AuthMethod that uses the given key
// pairs.
func PublicKeys(signers ...Signer) AuthMethod {
return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
}
// PublicKeysCallback returns an AuthMethod that runs the given
// function to obtain a list of key pairs.
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
return publicKeyCallback(getSigners)
}
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
func handleAuthResponse(c packetConn) (bool, []string, error) {
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
}
// KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After
// successful authentication, the server may send a challenge with no
// questions, for which the user and instruction messages should be
// printed. RFC 4256 section 3.3 details how the UI should behave for
// both CLI and GUI environments.
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
// KeyboardInteractive returns a AuthMethod using a prompt/response
// sequence controlled by the server.
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
return challenge
}
func (cb KeyboardInteractiveChallenge) method() string {
return "keyboard-interactive"
}
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type initiateMsg struct {
User string `sshtype:"50"`
Service string
Method string
Language string
Submethods string
}
if err := c.writePacket(Marshal(&initiateMsg{
User: user,
Service: serviceSSH,
Method: "keyboard-interactive",
})); err != nil {
return false, nil, err
}
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
// like handleAuthResponse, but with less options.
switch packet[0] {
case msgUserAuthBanner:
// TODO: Print banners during userauth.
continue
case msgUserAuthInfoRequest:
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
var msg userAuthInfoRequestMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
// Manually unpack the prompt/echo pairs.
rest := msg.Prompts
var prompts []string
var echos []bool
for i := 0; i < int(msg.NumPrompts); i++ {
prompt, r, ok := parseString(rest)
if !ok || len(r) == 0 {
return false, nil, errors.New("ssh: prompt format error")
}
prompts = append(prompts, string(prompt))
echos = append(echos, r[0] != 0)
rest = r[1:]
}
if len(rest) != 0 {
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
}
answers, err := cb(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
}
if len(answers) != len(prompts) {
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
}
responseLength := 1 + 4
for _, a := range answers {
responseLength += stringLength(len(a))
}
serialized := make([]byte, responseLength)
p := serialized
p[0] = msgUserAuthInfoResponse
p = p[1:]
p = marshalUint32(p, uint32(len(answers)))
for _, a := range answers {
p = marshalString(p, []byte(a))
}
if err := c.writePacket(serialized); err != nil {
return false, nil, err
}
}
}

View file

@ -0,0 +1,393 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"strings"
"testing"
)
type keyboardInteractive map[string]string
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
var answers []string
for _, q := range questions {
answers = append(answers, cr[q])
}
return answers, nil
}
// reused internally by tests
var clientPassword = "tiger"
// tryAuth runs a handshake with a given config against an SSH server
// with config serverConfig
func tryAuth(t *testing.T, config *ClientConfig) error {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
certChecker := CertChecker{
IsAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
IsRevoked: func(c *Certificate) bool {
return c.Serial == 666
},
}
serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
PublicKeyCallback: certChecker.Authenticate,
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
ans, err := challenge("user",
"instruction",
[]string{"question1", "question2"},
[]bool{true, true})
if err != nil {
return nil, err
}
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
if ok {
challenge("user", "motd", nil, nil)
return nil, nil
}
return nil, errors.New("keyboard-interactive failed")
},
AuthLogCallback: func(conn ConnMetadata, method string, err error) {
t.Logf("user %q, method %q: %v", conn.User(), method, err)
},
}
serverConfig.AddHostKey(testSigners["rsa"])
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", config)
return err
}
func TestClientAuthPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password(clientPassword),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodFallback(t *testing.T) {
var passwordCalled bool
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
PasswordCallback(
func() (string, error) {
passwordCalled = true
return "WRONG", nil
}),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if passwordCalled {
t.Errorf("password auth tried before public-key auth.")
}
}
func TestAuthMethodWrongPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password("wrong"),
PublicKeys(testSigners["rsa"]),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "answer2",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "WRONG",
})
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
}
}
// the mock server will only authenticate ssh-rsa keys
func TestAuthMethodInvalidPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"]),
},
}
if err := tryAuth(t, config); err == nil {
t.Fatalf("dsa private key should not have authenticated with rsa public key")
}
}
// the client should authenticate with the second key
func TestAuthMethodRSAandDSA(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
}
func TestClientHMAC(t *testing.T) {
for _, mac := range supportedMACs {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
Config: Config{
MACs: []string{mac},
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
}
}
}
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
Ciphers: []string{"aes128-cbc"}, // not currently supported
},
}
if err := tryAuth(t, config); err == nil {
t.Errorf("expected no ciphers in common")
}
}
func TestClientUnsupportedKex(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(),
},
Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
}
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") {
t.Errorf("got %v, expected 'no common algorithms'", err)
}
}
func TestClientLoginCert(t *testing.T) {
cert := &Certificate{
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: UserCert,
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
if err != nil {
t.Fatalf("NewCertSigner: %v", err)
}
clientConfig := &ClientConfig{
User: "user",
}
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
t.Log("should succeed")
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
t.Log("corrupted signature")
cert.Signature.Blob[0]++
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with corrupted sig")
}
t.Log("revoked")
cert.Serial = 666
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("revoked cert login succeeded")
}
cert.Serial = 1
t.Log("sign with wrong key")
cert.SignCert(rand.Reader, testSigners["dsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with non-authoritive key")
}
t.Log("host cert")
cert.CertType = HostCert
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong type")
}
cert.CertType = UserCert
t.Log("principal specified")
cert.ValidPrincipals = []string{"user"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
t.Log("wrong principal specified")
cert.ValidPrincipals = []string{"fred"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with wrong principal")
}
cert.ValidPrincipals = nil
t.Log("added critical option")
cert.CriticalOptions = map[string]string{"root-access": "yes"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with unrecognized critical option")
}
t.Log("allowed source address")
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login with source-address failed: %v", err)
}
t.Log("disallowed source address")
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login with source-address succeeded")
}
}
func testPermissionsPassing(withPermissions bool, t *testing.T) {
serverConfig := &ServerConfig{
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "nopermissions" {
return nil, nil
} else {
return &Permissions{}, nil
}
},
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
}
if withPermissions {
clientConfig.User = "permissions"
} else {
clientConfig.User = "nopermissions"
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatal(err)
}
if p := serverConn.Permissions; (p != nil) != withPermissions {
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p)
}
}
func TestPermissionsPassing(t *testing.T) {
testPermissionsPassing(true, t)
}
func TestNoPermissionsPassing(t *testing.T) {
testPermissionsPassing(false, t)
}

View file

@ -0,0 +1,39 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"net"
"testing"
)
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
receivedVersion := make(chan string, 1)
go func() {
version, err := readVersion(serverConn)
if err != nil {
receivedVersion <- ""
} else {
receivedVersion <- string(version)
}
serverConn.Close()
}()
NewClientConn(clientConn, "", config)
actual := <-receivedVersion
if actual != expected {
t.Fatalf("got %s; want %s", actual, expected)
}
}
func TestCustomClientVersion(t *testing.T) {
version := "Test-Client-Version-0.0"
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
}

View file

@ -0,0 +1,357 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto"
"crypto/rand"
"fmt"
"io"
"sync"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
// These are string constants in the SSH protocol.
const (
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
serviceSSH = "ssh-connection"
)
// supportedCiphers specifies the supported ciphers in preference order.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com",
"arcfour256", "arcfour128",
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
// preference order.
var supportedKexAlgos = []string{
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSA, KeyAlgoDSA,
}
// supportedMACs specifies a default set of MAC algorithms in preference order.
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
var supportedMACs = []string{
"hmac-sha1", "hmac-sha1-96",
}
var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported algorithms to their respective
// hashes needed for signature verification.
var hashFuncs = map[string]crypto.Hash{
KeyAlgoRSA: crypto.SHA1,
KeyAlgoDSA: crypto.SHA1,
KeyAlgoECDSA256: crypto.SHA256,
KeyAlgoECDSA384: crypto.SHA384,
KeyAlgoECDSA521: crypto.SHA512,
CertAlgoRSAv01: crypto.SHA1,
CertAlgoDSAv01: crypto.SHA1,
CertAlgoECDSA256v01: crypto.SHA256,
CertAlgoECDSA384v01: crypto.SHA384,
CertAlgoECDSA521v01: crypto.SHA512,
}
// unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
func unexpectedMessageError(expected, got uint8) error {
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
}
// parseError results from a malformed SSH message.
func parseError(tag uint8) error {
return fmt.Errorf("ssh: parse error in message type %d", tag)
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
if clientAlgo == serverAlgo {
return clientAlgo, true
}
}
}
return
}
func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCipher string, ok bool) {
for _, clientCipher := range clientCiphers {
for _, serverCipher := range serverCiphers {
// reject the cipher if we have no cipherModes definition
if clientCipher == serverCipher && cipherModes[clientCipher] != nil {
return clientCipher, true
}
}
}
return
}
type directionAlgorithms struct {
Cipher string
MAC string
Compression string
}
type algorithms struct {
kex string
hostKey string
w directionAlgorithms
r directionAlgorithms
}
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
var ok bool
result := &algorithms{}
result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
return
}
result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
return
}
result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
return result
}
// If rekeythreshold is too small, we can't make any progress sending
// stuff.
const minRekeyThreshold uint64 = 256
// Config contains configuration data common to both ServerConfig and
// ClientConfig.
type Config struct {
// Rand provides the source of entropy for cryptographic
// primitives. If Rand is nil, the cryptographic random reader
// in package crypto/rand will be used.
Rand io.Reader
// The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If
// unspecified, 1 gigabyte is used.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a
// default set of algorithms is used.
KeyExchanges []string
// The allowed cipher algorithms. If unspecified then a sensible
// default is used.
Ciphers []string
// The allowed MAC algorithms. If unspecified then a sensible default
// is used.
MACs []string
}
// SetDefaults sets sensible values for unset fields in config. This is
// exported for testing: Configs passed to SSH functions are copied and have
// default values set automatically.
func (c *Config) SetDefaults() {
if c.Rand == nil {
c.Rand = rand.Reader
}
if c.Ciphers == nil {
c.Ciphers = supportedCiphers
}
if c.KeyExchanges == nil {
c.KeyExchanges = supportedKexAlgos
}
if c.MACs == nil {
c.MACs = supportedMACs
}
if c.RekeyThreshold == 0 {
// RFC 4253, section 9 suggests rekeying after 1G.
c.RekeyThreshold = 1 << 30
}
if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold
}
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct {
Session []byte
Type byte
User string
Service string
Method string
Sign bool
Algo []byte
PubKey []byte
}{
sessionId,
msgUserAuthRequest,
req.User,
req.Service,
req.Method,
true,
algo,
pubKey,
}
return Marshal(data)
}
func appendU16(buf []byte, n uint16) []byte {
return append(buf, byte(n>>8), byte(n))
}
func appendU32(buf []byte, n uint32) []byte {
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
func appendU64(buf []byte, n uint64) []byte {
return append(buf,
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
func appendInt(buf []byte, n int) []byte {
return appendU32(buf, uint32(n))
}
func appendString(buf []byte, s string) []byte {
buf = appendU32(buf, uint32(len(s)))
buf = append(buf, s...)
return buf
}
func appendBool(buf []byte, b bool) []byte {
if b {
return append(buf, 1)
}
return append(buf, 0)
}
// newCond is a helper to hide the fact that there is no usable zero
// value for sync.Cond.
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
// window represents the buffer available to clients
// wishing to write to a channel.
type window struct {
*sync.Cond
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
writeWaiters int
closed bool
}
// add adds win to the amount of window available
// for consumers.
func (w *window) add(win uint32) bool {
// a zero sized window adjust is a noop.
if win == 0 {
return true
}
w.L.Lock()
if w.win+win < win {
w.L.Unlock()
return false
}
w.win += win
// It is unusual that multiple goroutines would be attempting to reserve
// window space, but not guaranteed. Use broadcast to notify all waiters
// that additional window is available.
w.Broadcast()
w.L.Unlock()
return true
}
// close sets the window to closed, so all reservations fail
// immediately.
func (w *window) close() {
w.L.Lock()
w.closed = true
w.Broadcast()
w.L.Unlock()
}
// reserve reserves win from the available window capacity.
// If no capacity remains, reserve will block. reserve may
// return less than requested.
func (w *window) reserve(win uint32) (uint32, error) {
var err error
w.L.Lock()
w.writeWaiters++
w.Broadcast()
for w.win == 0 && !w.closed {
w.Wait()
}
w.writeWaiters--
if w.win < win {
win = w.win
}
w.win -= win
if w.closed {
err = io.EOF
}
w.L.Unlock()
return win, err
}
// waitWriterBlocked waits until some goroutine is blocked for further
// writes. It is used in tests only.
func (w *window) waitWriterBlocked() {
w.Cond.L.Lock()
for w.writeWaiters == 0 {
w.Cond.Wait()
}
w.Cond.L.Unlock()
}

View file

@ -0,0 +1,144 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"fmt"
"net"
)
// OpenChannelError is returned if the other side rejects an
// OpenChannel request.
type OpenChannelError struct {
Reason RejectionReason
Message string
}
func (e *OpenChannelError) Error() string {
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
}
// ConnMetadata holds metadata for the connection.
type ConnMetadata interface {
// User returns the user ID for this connection.
// It is empty if no authentication is used.
User() string
// SessionID returns the sesson hash, also denoted by H.
SessionID() []byte
// ClientVersion returns the client's version string as hashed
// into the session ID.
ClientVersion() []byte
// ServerVersion returns the client's version string as hashed
// into the session ID.
ServerVersion() []byte
// RemoteAddr returns the remote address for this connection.
RemoteAddr() net.Addr
// LocalAddr returns the local address for this connection.
LocalAddr() net.Addr
}
// Conn represents an SSH connection for both server and client roles.
// Conn is the basis for implementing an application layer, such
// as ClientConn, which implements the traditional shell access for
// clients.
type Conn interface {
ConnMetadata
// SendRequest sends a global request, and returns the
// reply. If wantReply is true, it returns the response status
// and payload. See also RFC4254, section 4.
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
// OpenChannel tries to open an channel. If the request is
// rejected, it returns *OpenChannelError. On success it returns
// the SSH Channel and a Go channel for incoming, out-of-band
// requests. The Go channel must be serviced, or the
// connection will hang.
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
// Close closes the underlying network connection
Close() error
// Wait blocks until the connection has shut down, and returns the
// error causing the shutdown.
Wait() error
// TODO(hanwen): consider exposing:
// RequestKeyChange
// Disconnect
}
// DiscardRequests consumes and rejects all requests from the
// passed-in channel.
func DiscardRequests(in <-chan *Request) {
for req := range in {
if req.WantReply {
req.Reply(false, nil)
}
}
}
// A connection represents an incoming connection.
type connection struct {
transport *handshakeTransport
sshConn
// The connection protocol.
*mux
}
func (c *connection) Close() error {
return c.sshConn.conn.Close()
}
// sshconn provides net.Conn metadata, but disallows direct reads and
// writes.
type sshConn struct {
conn net.Conn
user string
sessionID []byte
clientVersion []byte
serverVersion []byte
}
func dup(src []byte) []byte {
dst := make([]byte, len(src))
copy(dst, src)
return dst
}
func (c *sshConn) User() string {
return c.user
}
func (c *sshConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *sshConn) Close() error {
return c.conn.Close()
}
func (c *sshConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *sshConn) SessionID() []byte {
return dup(c.sessionID)
}
func (c *sshConn) ClientVersion() []byte {
return dup(c.clientVersion)
}
func (c *sshConn) ServerVersion() []byte {
return dup(c.serverVersion)
}

View file

@ -0,0 +1,18 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package ssh implements an SSH client and server.
SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level
protocol is a remote shell and this is specifically implemented. However,
the multiplexed nature of SSH is exposed to users that wish to support
others.
References:
[PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
*/
package ssh

View file

@ -0,0 +1,210 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"golang.org/x/crypto/ssh/terminal"
)
func ExampleNewServerConn() {
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ServerConfig{
PasswordCallback: func(c ConnMetadata, pass []byte) (*Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
if c.User() == "testuser" && string(pass) == "tiger" {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
panic("Failed to load private key")
}
private, err := ParsePrivateKey(privateBytes)
if err != nil {
panic("Failed to parse private key")
}
config.AddHostKey(private)
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
panic("failed to listen for connection")
}
nConn, err := listener.Accept()
if err != nil {
panic("failed to accept incoming connection")
}
// Before use, a handshake must be performed on the incoming
// net.Conn.
_, chans, reqs, err := NewServerConn(nConn, config)
if err != nil {
panic("failed to handshake")
}
// The incoming Request channel must be serviced.
go DiscardRequests(reqs)
// Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is
// "session" and ServerShell may be used to present a simple
// terminal interface.
if newChannel.ChannelType() != "session" {
newChannel.Reject(UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
panic("could not accept channel.")
}
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
go func(in <-chan *Request) {
for req := range in {
ok := false
switch req.Type {
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any
// commands, only the
// default shell.
ok = false
}
}
req.Reply(ok, nil)
}
}(requests)
term := terminal.NewTerminal(channel, "> ")
go func() {
defer channel.Close()
for {
line, err := term.ReadLine()
if err != nil {
break
}
fmt.Println(line)
}
}()
}
}
func ExampleDial() {
// An SSH client is represented with a ClientConn. Currently only
// the "password" authentication method is supported.
//
// To authenticate with the remote server you must pass at least one
// implementation of AuthMethod via the Auth field in ClientConfig.
config := &ClientConfig{
User: "username",
Auth: []AuthMethod{
Password("yourpassword"),
},
}
client, err := Dial("tcp", "yourserver.com:22", config)
if err != nil {
panic("Failed to dial: " + err.Error())
}
// Each ClientConn can support multiple interactive sessions,
// represented by a Session.
session, err := client.NewSession()
if err != nil {
panic("Failed to create session: " + err.Error())
}
defer session.Close()
// Once a Session is created, you can execute a single command on
// the remote side using the Run method.
var b bytes.Buffer
session.Stdout = &b
if err := session.Run("/usr/bin/whoami"); err != nil {
panic("Failed to run: " + err.Error())
}
fmt.Println(b.String())
}
func ExampleClient_Listen() {
config := &ClientConfig{
User: "username",
Auth: []AuthMethod{
Password("password"),
},
}
// Dial your ssh server.
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatalf("unable to connect: %s", err)
}
defer conn.Close()
// Request the remote side to open port 8080 on all interfaces.
l, err := conn.Listen("tcp", "0.0.0.0:8080")
if err != nil {
log.Fatalf("unable to register tcp forward: %v", err)
}
defer l.Close()
// Serve HTTP with your SSH server acting as a reverse proxy.
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
fmt.Fprintf(resp, "Hello world!\n")
}))
}
func ExampleSession_RequestPty() {
// Create client config
config := &ClientConfig{
User: "username",
Auth: []AuthMethod{
Password("password"),
},
}
// Connect to ssh server
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatalf("unable to connect: %s", err)
}
defer conn.Close()
// Create a session
session, err := conn.NewSession()
if err != nil {
log.Fatalf("unable to create session: %s", err)
}
defer session.Close()
// Set up terminal modes
modes := TerminalModes{
ECHO: 0, // disable echoing
TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
// Request pseudo terminal
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
log.Fatalf("request for pseudo terminal failed: %s", err)
}
// Start remote shell
if err := session.Shell(); err != nil {
log.Fatalf("failed to start shell: %s", err)
}
}

View file

@ -0,0 +1,393 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
)
// debugHandshake, if set, prints messages sent and received. Key
// exchange messages are printed as if DH were used, so the debug
// messages are wrong when using ECDH.
const debugHandshake = false
// keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions.
type keyingTransport interface {
packetConn
// prepareKeyChange sets up a key change. The key change for a
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
// getSessionID returns the session ID. prepareKeyChange must
// have been called once.
getSessionID() []byte
}
// rekeyingTransport is the interface of handshakeTransport that we
// (internally) expose to ClientConn and ServerConn.
type rekeyingTransport interface {
packetConn
// requestKeyChange asks the remote side to change keys. All
// writes are blocked until the key change succeeds, which is
// signaled by reading a msgNewKeys.
requestKeyChange() error
// getSessionID returns the session ID. This is only valid
// after the first key change has completed.
getSessionID() []byte
}
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
type handshakeTransport struct {
conn keyingTransport
config *Config
serverVersion []byte
clientVersion []byte
hostKeys []Signer // If hostKeys are given, we are the server.
// On read error, incoming is closed, and readError is set.
incoming chan []byte
readError error
// data for host key checking
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
dialAddress string
remoteAddr net.Addr
readSinceKex uint64
// Protects the writing side of the connection
mu sync.Mutex
cond *sync.Cond
sentInitPacket []byte
sentInitMsg *kexInitMsg
writtenSinceKex uint64
writeError error
}
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
t := &handshakeTransport{
conn: conn,
serverVersion: serverVersion,
clientVersion: clientVersion,
incoming: make(chan []byte, 16),
config: config,
}
t.cond = sync.NewCond(&t.mu)
return t
}
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
t.dialAddress = dialAddr
t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback
go t.readLoop()
return t
}
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
t.hostKeys = config.hostKeys
go t.readLoop()
return t
}
func (t *handshakeTransport) getSessionID() []byte {
return t.conn.getSessionID()
}
func (t *handshakeTransport) id() string {
if len(t.hostKeys) > 0 {
return "server"
}
return "client"
}
func (t *handshakeTransport) readPacket() ([]byte, error) {
p, ok := <-t.incoming
if !ok {
return nil, t.readError
}
return p, nil
}
func (t *handshakeTransport) readLoop() {
for {
p, err := t.readOnePacket()
if err != nil {
t.readError = err
close(t.incoming)
break
}
if p[0] == msgIgnore || p[0] == msgDebug {
continue
}
t.incoming <- p
}
}
func (t *handshakeTransport) readOnePacket() ([]byte, error) {
if t.readSinceKex > t.config.RekeyThreshold {
if err := t.requestKeyChange(); err != nil {
return nil, err
}
}
p, err := t.conn.readPacket()
if err != nil {
return nil, err
}
t.readSinceKex += uint64(len(p))
if debugHandshake {
msg, err := decode(p)
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
}
if p[0] != msgKexInit {
return p, nil
}
err = t.enterKeyExchange(p)
t.mu.Lock()
if err != nil {
// drop connection
t.conn.Close()
t.writeError = err
}
if debugHandshake {
log.Printf("%s exited key exchange, err %v", t.id(), err)
}
// Unblock writers.
t.sentInitMsg = nil
t.sentInitPacket = nil
t.cond.Broadcast()
t.writtenSinceKex = 0
t.mu.Unlock()
if err != nil {
return nil, err
}
t.readSinceKex = 0
return []byte{msgNewKeys}, nil
}
// sendKexInit sends a key change message, and returns the message
// that was sent. After initiating the key change, all writes will be
// blocked until the change is done, and a failed key change will
// close the underlying transport. This function is safe for
// concurrent use by multiple goroutines.
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
t.mu.Lock()
defer t.mu.Unlock()
return t.sendKexInitLocked()
}
func (t *handshakeTransport) requestKeyChange() error {
_, _, err := t.sendKexInit()
return err
}
// sendKexInitLocked sends a key change message. t.mu must be locked
// while this happens.
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
// kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
// second kexInit.
if t.sentInitMsg != nil {
return t.sentInitMsg, t.sentInitPacket, nil
}
msg := &kexInitMsg{
KexAlgos: t.config.KeyExchanges,
CiphersClientServer: t.config.Ciphers,
CiphersServerClient: t.config.Ciphers,
MACsClientServer: t.config.MACs,
MACsServerClient: t.config.MACs,
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
io.ReadFull(rand.Reader, msg.Cookie[:])
if len(t.hostKeys) > 0 {
for _, k := range t.hostKeys {
msg.ServerHostKeyAlgos = append(
msg.ServerHostKeyAlgos, k.PublicKey().Type())
}
} else {
msg.ServerHostKeyAlgos = supportedHostKeyAlgos
}
packet := Marshal(msg)
// writePacket destroys the contents, so save a copy.
packetCopy := make([]byte, len(packet))
copy(packetCopy, packet)
if err := t.conn.writePacket(packetCopy); err != nil {
return nil, nil, err
}
t.sentInitMsg = msg
t.sentInitPacket = packet
return msg, packet, nil
}
func (t *handshakeTransport) writePacket(p []byte) error {
t.mu.Lock()
if t.writtenSinceKex > t.config.RekeyThreshold {
t.sendKexInitLocked()
}
for t.sentInitMsg != nil {
t.cond.Wait()
}
if t.writeError != nil {
return t.writeError
}
t.writtenSinceKex += uint64(len(p))
var err error
switch p[0] {
case msgKexInit:
err = errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
err = errors.New("ssh: only handshakeTransport can send newKeys")
default:
err = t.conn.writePacket(p)
}
t.mu.Unlock()
return err
}
func (t *handshakeTransport) Close() error {
return t.conn.Close()
}
// enterKeyExchange runs the key exchange.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
if debugHandshake {
log.Printf("%s entered key exchange", t.id())
}
myInit, myInitPacket, err := t.sendKexInit()
if err != nil {
return err
}
otherInit := &kexInitMsg{}
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
return err
}
magics := handshakeMagics{
clientVersion: t.clientVersion,
serverVersion: t.serverVersion,
clientKexInit: otherInitPacket,
serverKexInit: myInitPacket,
}
clientInit := otherInit
serverInit := myInit
if len(t.hostKeys) == 0 {
clientInit = myInit
serverInit = otherInit
magics.clientKexInit = myInitPacket
magics.serverKexInit = otherInitPacket
}
algs := findAgreedAlgorithms(clientInit, serverInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
}
// We don't send FirstKexFollows, but we handle receiving it.
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
// other side sent a kex message for the wrong algorithm,
// which we have to ignore.
if _, err := t.conn.readPacket(); err != nil {
return err
}
}
kex, ok := kexAlgoMap[algs.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
}
var result *kexResult
if len(t.hostKeys) > 0 {
result, err = t.server(kex, algs, &magics)
} else {
result, err = t.client(kex, algs, &magics)
}
if err != nil {
return err
}
t.conn.prepareKeyChange(algs, result)
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
if packet, err := t.conn.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
return nil
}
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
var hostKey Signer
for _, k := range t.hostKeys {
if algs.hostKey == k.PublicKey().Type() {
hostKey = k
}
}
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
return r, err
}
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
result, err := kex.Client(t.conn, t.config.Rand, magics)
if err != nil {
return nil, err
}
hostKey, err := ParsePublicKey(result.HostKey)
if err != nil {
return nil, err
}
if err := verifyHostKeySignature(hostKey, result); err != nil {
return nil, err
}
if t.hostKeyCallback != nil {
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
if err != nil {
return nil, err
}
}
return result, nil
}

View file

@ -0,0 +1,311 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"testing"
)
type testChecker struct {
calls []string
}
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
if dialAddr == "bad" {
return fmt.Errorf("dialAddr is bad")
}
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
}
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
return nil
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1, c2, nil
}
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe()
if err != nil {
return nil, nil, err
}
trC := newTransport(a, rand.Reader, true)
trS := newTransport(b, rand.Reader, false)
clientConf.SetDefaults()
v := []byte("version")
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
return client, server, nil
}
func TestHandshakeBasic(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
go func() {
// Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the
// handshake in progress
for i := 0; i < 10; i++ {
p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil {
t.Fatalf("sendPacket: %v", err)
}
if i == 5 {
// halfway through, we request a key change.
_, _, err := trC.sendKexInit()
if err != nil {
t.Fatalf("sendKexInit: %v", err)
}
}
}
trC.Close()
}()
// Server checks that client messages come in cleanly
i := 0
for {
p, err := trS.readPacket()
if err != nil {
break
}
if p[0] == msgNewKeys {
continue
}
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %q, want %q", i, p, want)
}
i++
}
if i != 10 {
t.Errorf("received %d messages, want 10.", i)
}
// If all went well, we registered exactly 1 key change.
if len(checker.calls) != 1 {
t.Fatalf("got %d host key checks, want 1", len(checker.calls))
}
pub := testSigners["ecdsa"].PublicKey()
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
if want != checker.calls[0] {
t.Errorf("got %q want %q for host key check", checker.calls[0], want)
}
}
func TestHandshakeError(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// send a packet
packet := []byte{msgRequestSuccess, 42}
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
// Now request a key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
// the key change will fail, and afterwards we can't write.
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
t.Errorf("writePacket after botched rekey succeeded.")
}
readback, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if bytes.Compare(readback, packet) != 0 {
t.Errorf("got %q want %q", readback, packet)
}
readback, err = trS.readPacket()
if err == nil {
t.Errorf("got a message %q after failed key change", readback)
}
}
func TestHandshakeTwice(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// send a packet
packet := make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
// Now request a key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
// Send another packet. Use a fresh one, since writePacket destroys.
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
// 2nd key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
for i := 0; i < 5; i++ {
msg, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if msg[0] == msgNewKeys {
continue
}
if bytes.Compare(msg, packet) != 0 {
t.Errorf("packet %d: got %q want %q", i, msg, packet)
}
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, want 2", len(checker.calls))
}
}
func TestHandshakeAutoRekeyWrite(t *testing.T) {
checker := &testChecker{}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
for i := 0; i < 5; i++ {
packet := make([]byte, 251)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
}
j := 0
for ; j < 5; j++ {
_, err := trS.readPacket()
if err != nil {
break
}
}
if j != 5 {
t.Errorf("got %d, want 5 messages", j)
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, wanted 2", len(checker.calls))
}
}
type syncChecker struct {
called chan int
}
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
t.called <- 1
return nil
}
func TestHandshakeAutoRekeyRead(t *testing.T) {
sync := &syncChecker{make(chan int, 2)}
clientConf := &ClientConfig{
HostKeyCallback: sync.Check,
}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
packet := make([]byte, 501)
packet[0] = msgRequestSuccess
if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err)
}
// While we read out the packet, a key change will be
// initiated.
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
<-sync.called
}

View file

@ -0,0 +1,386 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"io"
"math/big"
)
const (
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
kexAlgoECDH256 = "ecdh-sha2-nistp256"
kexAlgoECDH384 = "ecdh-sha2-nistp384"
kexAlgoECDH521 = "ecdh-sha2-nistp521"
)
// kexResult captures the outcome of a key exchange.
type kexResult struct {
// Session hash. See also RFC 4253, section 8.
H []byte
// Shared secret. See also RFC 4253, section 8.
K []byte
// Host key as hashed into H.
HostKey []byte
// Signature of H.
Signature []byte
// A cryptographic hash function that matches the security
// level of the key exchange algorithm. It is used for
// calculating H, and for deriving keys from H and K.
Hash crypto.Hash
// The session ID, which is the first H computed. This is used
// to signal data inside transport.
SessionID []byte
}
// handshakeMagics contains data that is always included in the
// session hash.
type handshakeMagics struct {
clientVersion, serverVersion []byte
clientKexInit, serverKexInit []byte
}
func (m *handshakeMagics) write(w io.Writer) {
writeString(w, m.clientVersion)
writeString(w, m.serverVersion)
writeString(w, m.clientKexInit)
writeString(w, m.serverKexInit)
}
// kexAlgorithm abstracts different key exchange algorithms.
type kexAlgorithm interface {
// Server runs server-side key agreement, signing the result
// with a hostkey.
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error)
// Client runs the client-side key agreement. Caller is
// responsible for verifying the host key signature.
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
}
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
type dhGroup struct {
g, p *big.Int
}
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
return nil, errors.New("ssh: DH parameter out of bounds")
}
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
}
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
hashFunc := crypto.SHA1
x, err := rand.Int(randSource, group.p)
if err != nil {
return nil, err
}
X := new(big.Int).Exp(group.g, x, group.p)
kexDHInit := kexDHInitMsg{
X: X,
}
if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
return nil, err
}
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var kexDHReply kexDHReplyMsg
if err = Unmarshal(packet, &kexDHReply); err != nil {
return nil, err
}
kInt, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, err
}
h := hashFunc.New()
magics.write(h)
writeString(h, kexDHReply.HostKey)
writeInt(h, X)
writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
h.Write(K)
return &kexResult{
H: h.Sum(nil),
K: K,
HostKey: kexDHReply.HostKey,
Signature: kexDHReply.Signature,
Hash: crypto.SHA1,
}, nil
}
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
hashFunc := crypto.SHA1
packet, err := c.readPacket()
if err != nil {
return
}
var kexDHInit kexDHInitMsg
if err = Unmarshal(packet, &kexDHInit); err != nil {
return
}
y, err := rand.Int(randSource, group.p)
if err != nil {
return
}
Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y)
if err != nil {
return nil, err
}
hostKeyBytes := priv.PublicKey().Marshal()
h := hashFunc.New()
magics.write(h)
writeString(h, hostKeyBytes)
writeInt(h, kexDHInit.X)
writeInt(h, Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
h.Write(K)
H := h.Sum(nil)
// H is already a hash, but the hostkey signing will apply its
// own key-specific hash algorithm.
sig, err := signAndMarshal(priv, randSource, H)
if err != nil {
return nil, err
}
kexDHReply := kexDHReplyMsg{
HostKey: hostKeyBytes,
Y: Y,
Signature: sig,
}
packet = Marshal(&kexDHReply)
err = c.writePacket(packet)
return &kexResult{
H: H,
K: K,
HostKey: hostKeyBytes,
Signature: sig,
Hash: crypto.SHA1,
}, nil
}
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
// described in RFC 5656, section 4.
type ecdh struct {
curve elliptic.Curve
}
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
if err != nil {
return nil, err
}
kexInit := kexECDHInitMsg{
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
}
serialized := Marshal(&kexInit)
if err := c.writePacket(serialized); err != nil {
return nil, err
}
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var reply kexECDHReplyMsg
if err = Unmarshal(packet, &reply); err != nil {
return nil, err
}
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
if err != nil {
return nil, err
}
// generate shared secret
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
h := ecHash(kex.curve).New()
magics.write(h)
writeString(h, reply.HostKey)
writeString(h, kexInit.ClientPubKey)
writeString(h, reply.EphemeralPubKey)
K := make([]byte, intLength(secret))
marshalInt(K, secret)
h.Write(K)
return &kexResult{
H: h.Sum(nil),
K: K,
HostKey: reply.HostKey,
Signature: reply.Signature,
Hash: ecHash(kex.curve),
}, nil
}
// unmarshalECKey parses and checks an EC key.
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
x, y = elliptic.Unmarshal(curve, pubkey)
if x == nil {
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
}
if !validateECPublicKey(curve, x, y) {
return nil, nil, errors.New("ssh: public key not on curve")
}
return x, y, nil
}
// validateECPublicKey checks that the point is a valid public key for
// the given curve. See [SEC1], 3.2.2
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
if x.Sign() == 0 && y.Sign() == 0 {
return false
}
if x.Cmp(curve.Params().P) >= 0 {
return false
}
if y.Cmp(curve.Params().P) >= 0 {
return false
}
if !curve.IsOnCurve(x, y) {
return false
}
// We don't check if N * PubKey == 0, since
//
// - the NIST curves have cofactor = 1, so this is implicit.
// (We don't foresee an implementation that supports non NIST
// curves)
//
// - for ephemeral keys, we don't need to worry about small
// subgroup attacks.
return true
}
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var kexECDHInit kexECDHInitMsg
if err = Unmarshal(packet, &kexECDHInit); err != nil {
return nil, err
}
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
if err != nil {
return nil, err
}
// We could cache this key across multiple users/multiple
// connection attempts, but the benefit is small. OpenSSH
// generates a new key for each incoming connection.
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
if err != nil {
return nil, err
}
hostKeyBytes := priv.PublicKey().Marshal()
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
// generate shared secret
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
h := ecHash(kex.curve).New()
magics.write(h)
writeString(h, hostKeyBytes)
writeString(h, kexECDHInit.ClientPubKey)
writeString(h, serializedEphKey)
K := make([]byte, intLength(secret))
marshalInt(K, secret)
h.Write(K)
H := h.Sum(nil)
// H is already a hash, but the hostkey signing will apply its
// own key-specific hash algorithm.
sig, err := signAndMarshal(priv, rand, H)
if err != nil {
return nil, err
}
reply := kexECDHReplyMsg{
EphemeralPubKey: serializedEphKey,
HostKey: hostKeyBytes,
Signature: sig,
}
serialized := Marshal(&reply)
if err := c.writePacket(serialized); err != nil {
return nil, err
}
return &kexResult{
H: H,
K: K,
HostKey: reply.HostKey,
Signature: sig,
Hash: ecHash(kex.curve),
}, nil
}
var kexAlgoMap = map[string]kexAlgorithm{}
func init() {
// This is the group called diffie-hellman-group1-sha1 in RFC
// 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
}
// This is the group called diffie-hellman-group14-sha1 in RFC
// 4253 and Oakley Group 14 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
}
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
}

View file

@ -0,0 +1,48 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Key exchange tests.
import (
"crypto/rand"
"reflect"
"testing"
)
func TestKexes(t *testing.T) {
type kexResultErr struct {
result *kexResult
err error
}
for name, kex := range kexAlgoMap {
a, b := memPipe()
s := make(chan kexResultErr, 1)
c := make(chan kexResultErr, 1)
var magics handshakeMagics
go func() {
r, e := kex.Client(a, rand.Reader, &magics)
c <- kexResultErr{r, e}
}()
go func() {
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
s <- kexResultErr{r, e}
}()
clientRes := <-c
serverRes := <-s
if clientRes.err != nil {
t.Errorf("client: %v", clientRes.err)
}
if serverRes.err != nil {
t.Errorf("server: %v", serverRes.err)
}
if !reflect.DeepEqual(clientRes.result, serverRes.result) {
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
}
}
}

View file

@ -0,0 +1,628 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
)
// These constants represent the algorithm names for key types supported by this
// package.
const (
KeyAlgoRSA = "ssh-rsa"
KeyAlgoDSA = "ssh-dss"
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
)
// parsePubKey parses a public key of the given algorithm.
// Use ParsePublicKey for keys with prepended algorithm.
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
switch algo {
case KeyAlgoRSA:
return parseRSA(in)
case KeyAlgoDSA:
return parseDSA(in)
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
return parseECDSA(in)
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
cert, err := parseCert(in, certToPrivAlgo(algo))
if err != nil {
return nil, nil, err
}
return cert, nil, nil
}
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err)
}
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
// (see sshd(8) manual page) once the options and key type fields have been
// removed.
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
in = bytes.TrimSpace(in)
i := bytes.IndexAny(in, " \t")
if i == -1 {
i = len(in)
}
base64Key := in[:i]
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key)))
n, err := base64.StdEncoding.Decode(key, base64Key)
if err != nil {
return nil, "", err
}
key = key[:n]
out, err = ParsePublicKey(key)
if err != nil {
return nil, "", err
}
comment = string(bytes.TrimSpace(in[i:]))
return out, comment, nil
}
// ParseAuthorizedKeys parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
for len(in) > 0 {
end := bytes.IndexByte(in, '\n')
if end != -1 {
rest = in[end+1:]
in = in[:end]
} else {
rest = nil
}
end = bytes.IndexByte(in, '\r')
if end != -1 {
in = in[:end]
}
in = bytes.TrimSpace(in)
if len(in) == 0 || in[0] == '#' {
in = rest
continue
}
i := bytes.IndexAny(in, " \t")
if i == -1 {
in = rest
continue
}
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
return out, comment, options, rest, nil
}
// No key type recognised. Maybe there's an options field at
// the beginning.
var b byte
inQuote := false
var candidateOptions []string
optionStart := 0
for i, b = range in {
isEnd := !inQuote && (b == ' ' || b == '\t')
if (b == ',' && !inQuote) || isEnd {
if i-optionStart > 0 {
candidateOptions = append(candidateOptions, string(in[optionStart:i]))
}
optionStart = i + 1
}
if isEnd {
break
}
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) {
inQuote = !inQuote
}
}
for i < len(in) && (in[i] == ' ' || in[i] == '\t') {
i++
}
if i == len(in) {
// Invalid line: unmatched quote
in = rest
continue
}
in = in[i:]
i = bytes.IndexAny(in, " \t")
if i == -1 {
in = rest
continue
}
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
options = candidateOptions
return out, comment, options, rest, nil
}
in = rest
continue
}
return nil, "", nil, nil, errors.New("ssh: no key found")
}
// ParsePublicKey parses an SSH public key formatted for use in
// the SSH wire protocol according to RFC 4253, section 6.6.
func ParsePublicKey(in []byte) (out PublicKey, err error) {
algo, in, ok := parseString(in)
if !ok {
return nil, errShortRead
}
var rest []byte
out, rest, err = parsePubKey(in, string(algo))
if len(rest) > 0 {
return nil, errors.New("ssh: trailing junk in public key")
}
return out, err
}
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
// authorized_keys file. The return value ends with newline.
func MarshalAuthorizedKey(key PublicKey) []byte {
b := &bytes.Buffer{}
b.WriteString(key.Type())
b.WriteByte(' ')
e := base64.NewEncoder(base64.StdEncoding, b)
e.Write(key.Marshal())
e.Close()
b.WriteByte('\n')
return b.Bytes()
}
// PublicKey is an abstraction of different types of public keys.
type PublicKey interface {
// Type returns the key's type, e.g. "ssh-rsa".
Type() string
// Marshal returns the serialized key data in SSH wire format,
// with the name prefix.
Marshal() []byte
// Verify that sig is a signature on the given data using this
// key. This function will hash the data appropriately first.
Verify(data []byte, sig *Signature) error
}
// A Signer can create signatures that verify against a public key.
type Signer interface {
// PublicKey returns an associated PublicKey instance.
PublicKey() PublicKey
// Sign returns raw signature for the given data. This method
// will apply the hash specified for the keytype to the data.
Sign(rand io.Reader, data []byte) (*Signature, error)
}
type rsaPublicKey rsa.PublicKey
func (r *rsaPublicKey) Type() string {
return "ssh-rsa"
}
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
var w struct {
E *big.Int
N *big.Int
Rest []byte `ssh:"rest"`
}
if err := Unmarshal(in, &w); err != nil {
return nil, nil, err
}
if w.E.BitLen() > 24 {
return nil, nil, errors.New("ssh: exponent too large")
}
e := w.E.Int64()
if e < 3 || e&1 == 0 {
return nil, nil, errors.New("ssh: incorrect exponent")
}
var key rsa.PublicKey
key.E = int(e)
key.N = w.N
return (*rsaPublicKey)(&key), w.Rest, nil
}
func (r *rsaPublicKey) Marshal() []byte {
e := new(big.Int).SetInt64(int64(r.E))
wirekey := struct {
Name string
E *big.Int
N *big.Int
}{
KeyAlgoRSA,
e,
r.N,
}
return Marshal(&wirekey)
}
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != r.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
}
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
}
type rsaPrivateKey struct {
*rsa.PrivateKey
}
func (r *rsaPrivateKey) PublicKey() PublicKey {
return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
}
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
if err != nil {
return nil, err
}
return &Signature{
Format: r.PublicKey().Type(),
Blob: blob,
}, nil
}
type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string {
return "ssh-dss"
}
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
var w struct {
P, Q, G, Y *big.Int
Rest []byte `ssh:"rest"`
}
if err := Unmarshal(in, &w); err != nil {
return nil, nil, err
}
key := &dsaPublicKey{
Parameters: dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
},
Y: w.Y,
}
return key, w.Rest, nil
}
func (k *dsaPublicKey) Marshal() []byte {
w := struct {
Name string
P, Q, G, Y *big.Int
}{
k.Type(),
k.P,
k.Q,
k.G,
k.Y,
}
return Marshal(&w)
}
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
// Per RFC 4253, section 6.6,
// The value for 'dss_signature_blob' is encoded as a string containing
// r, followed by s (which are 160-bit integers, without lengths or
// padding, unsigned, and in network byte order).
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
if len(sig.Blob) != 40 {
return errors.New("ssh: DSA signature parse error")
}
r := new(big.Int).SetBytes(sig.Blob[:20])
s := new(big.Int).SetBytes(sig.Blob[20:])
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) {
return nil
}
return errors.New("ssh: signature did not verify")
}
type dsaPrivateKey struct {
*dsa.PrivateKey
}
func (k *dsaPrivateKey) PublicKey() PublicKey {
return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
}
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
r, s, err := dsa.Sign(rand, k.PrivateKey, digest)
if err != nil {
return nil, err
}
sig := make([]byte, 40)
rb := r.Bytes()
sb := s.Bytes()
copy(sig[20-len(rb):20], rb)
copy(sig[40-len(sb):], sb)
return &Signature{
Format: k.PublicKey().Type(),
Blob: sig,
}, nil
}
type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID()
}
func (key *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize {
case 256:
return "nistp256"
case 384:
return "nistp384"
case 521:
return "nistp521"
}
panic("ssh: unsupported ecdsa key size")
}
func supportedEllipticCurve(curve elliptic.Curve) bool {
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
}
// ecHash returns the hash to match the given elliptic curve, see RFC
// 5656, section 6.2.1
func ecHash(curve elliptic.Curve) crypto.Hash {
bitSize := curve.Params().BitSize
switch {
case bitSize <= 256:
return crypto.SHA256
case bitSize <= 384:
return crypto.SHA384
}
return crypto.SHA512
}
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
identifier, in, ok := parseString(in)
if !ok {
return nil, nil, errShortRead
}
key := new(ecdsa.PublicKey)
switch string(identifier) {
case "nistp256":
key.Curve = elliptic.P256()
case "nistp384":
key.Curve = elliptic.P384()
case "nistp521":
key.Curve = elliptic.P521()
default:
return nil, nil, errors.New("ssh: unsupported curve")
}
var keyBytes []byte
if keyBytes, in, ok = parseString(in); !ok {
return nil, nil, errShortRead
}
key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
if key.X == nil || key.Y == nil {
return nil, nil, errors.New("ssh: invalid curve point")
}
return (*ecdsaPublicKey)(key), in, nil
}
func (key *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
w := struct {
Name string
ID string
Key []byte
}{
key.Type(),
key.nistID(),
keyBytes,
}
return Marshal(&w)
}
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
}
h := ecHash(key.Curve).New()
h.Write(data)
digest := h.Sum(nil)
// Per RFC 5656, section 3.1.2,
// The ecdsa_signature_blob value has the following specific encoding:
// mpint r
// mpint s
var ecSig struct {
R *big.Int
S *big.Int
}
if err := Unmarshal(sig.Blob, &ecSig); err != nil {
return err
}
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
return nil
}
return errors.New("ssh: signature did not verify")
}
type ecdsaPrivateKey struct {
*ecdsa.PrivateKey
}
func (k *ecdsaPrivateKey) PublicKey() PublicKey {
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
}
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := ecHash(k.PrivateKey.PublicKey.Curve).New()
h.Write(data)
digest := h.Sum(nil)
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest)
if err != nil {
return nil, err
}
sig := make([]byte, intLength(r)+intLength(s))
rest := marshalInt(sig, r)
marshalInt(rest, s)
return &Signature{
Format: k.PublicKey().Type(),
Blob: sig,
}, nil
}
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
// returns a corresponding Signer instance. EC keys should use P256,
// P384 or P521.
func NewSignerFromKey(k interface{}) (Signer, error) {
var sshKey Signer
switch t := k.(type) {
case *rsa.PrivateKey:
sshKey = &rsaPrivateKey{t}
case *dsa.PrivateKey:
sshKey = &dsaPrivateKey{t}
case *ecdsa.PrivateKey:
if !supportedEllipticCurve(t.Curve) {
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
}
sshKey = &ecdsaPrivateKey{t}
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
}
return sshKey, nil
}
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
func NewPublicKey(k interface{}) (PublicKey, error) {
var sshKey PublicKey
switch t := k.(type) {
case *rsa.PublicKey:
sshKey = (*rsaPublicKey)(t)
case *ecdsa.PublicKey:
if !supportedEllipticCurve(t.Curve) {
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
}
sshKey = (*ecdsaPublicKey)(t)
case *dsa.PublicKey:
sshKey = (*dsaPublicKey)(t)
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
}
return sshKey, nil
}
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
// the same keys as ParseRawPrivateKey.
func ParsePrivateKey(pemBytes []byte) (Signer, error) {
key, err := ParseRawPrivateKey(pemBytes)
if err != nil {
return nil, err
}
return NewSignerFromKey(key)
}
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("ssh: no key found")
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "DSA PRIVATE KEY":
return ParseDSAPrivateKey(block.Bytes)
default:
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
}
}
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
// specified by the OpenSSL DSA man page.
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
var k struct {
Version int
P *big.Int
Q *big.Int
G *big.Int
Priv *big.Int
Pub *big.Int
}
rest, err := asn1.Unmarshal(der, &k)
if err != nil {
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error())
}
if len(rest) > 0 {
return nil, errors.New("ssh: garbage after DSA key")
}
return &dsa.PrivateKey{
PublicKey: dsa.PublicKey{
Parameters: dsa.Parameters{
P: k.P,
Q: k.Q,
G: k.G,
},
Y: k.Priv,
},
X: k.Pub,
}, nil
}

View file

@ -0,0 +1,306 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"reflect"
"strings"
"testing"
"golang.org/x/crypto/ssh/testdata"
)
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
case *rsaPublicKey:
return (*rsa.PublicKey)(k)
case *dsaPublicKey:
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
case *Certificate:
return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
roundtrip, err := ParsePublicKey(pub.Marshal())
if err != nil {
t.Errorf("ParsePublicKey(%T): %v", pub, err)
}
k1 := rawKey(pub)
k2 := rawKey(roundtrip)
if !reflect.DeepEqual(k1, k2) {
t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
}
}
}
func TestUnsupportedCurves(t *testing.T) {
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err)
}
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err)
}
}
func TestNewPublicKey(t *testing.T) {
for _, k := range testSigners {
raw := rawKey(k.PublicKey())
// Skip certificates, as NewPublicKey does not support them.
if _, ok := raw.(*Certificate); ok {
continue
}
pub, err := NewPublicKey(raw)
if err != nil {
t.Errorf("NewPublicKey(%#v): %v", raw, err)
}
if !reflect.DeepEqual(k.PublicKey(), pub) {
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
}
}
}
func TestKeySignVerify(t *testing.T) {
for _, priv := range testSigners {
pub := priv.PublicKey()
data := []byte("sign me")
sig, err := priv.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
if err := pub.Verify(data, sig); err != nil {
t.Errorf("publicKey.Verify(%T): %v", priv, err)
}
sig.Blob[5]++
if err := pub.Verify(data, sig); err == nil {
t.Errorf("publicKey.Verify on broken sig did not fail")
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
key := testPrivateKeys["rsa"]
rsa, ok := key.(*rsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
}
if err := rsa.Validate(); err != nil {
t.Errorf("Validate: %v", err)
}
}
func TestParseECPrivateKey(t *testing.T) {
key := testPrivateKeys["ecdsa"]
ecKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey)
}
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
t.Fatalf("public key does not validate.")
}
}
func TestParseDSA(t *testing.T) {
// We actually exercise the ParsePrivateKey codepath here, as opposed to
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
// uses.
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
if err != nil {
t.Fatalf("ParsePrivateKey returned error: %s", err)
}
data := []byte("sign me")
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
// Tests for authorized_keys parsing.
// getTestKey returns a public key, and its base64 encoding.
func getTestKey() (PublicKey, string) {
k := testPublicKeys["rsa"]
b := &bytes.Buffer{}
e := base64.NewEncoder(base64.StdEncoding, b)
e.Write(k.Marshal())
e.Close()
return k, b.String()
}
func TestMarshalParsePublicKey(t *testing.T) {
pub, pubSerialized := getTestKey()
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
authKeys := MarshalAuthorizedKey(pub)
actualFields := strings.Fields(string(authKeys))
if len(actualFields) == 0 {
t.Fatalf("failed authKeys: %v", authKeys)
}
// drop the comment
expectedFields := strings.Fields(line)[0:2]
if !reflect.DeepEqual(actualFields, expectedFields) {
t.Errorf("got %v, expected %v", actualFields, expectedFields)
}
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
if err != nil {
t.Fatalf("cannot parse %v: %v", line, err)
}
if !reflect.DeepEqual(actPub, pub) {
t.Errorf("got %v, expected %v", actPub, pub)
}
}
type authResult struct {
pubKey PublicKey
options []string
comments string
rest string
ok bool
}
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
rest := authKeys
var values []authResult
for len(rest) > 0 {
var r authResult
var err error
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
r.ok = (err == nil)
t.Log(err)
r.rest = string(rest)
values = append(values, r)
}
if !reflect.DeepEqual(values, expected) {
t.Errorf("got %#v, expected %#v", values, expected)
}
}
func TestAuthorizedKeyBasic(t *testing.T) {
pub, pubSerialized := getTestKey()
line := "ssh-rsa " + pubSerialized + " user@host"
testAuthorizedKeys(t, []byte(line),
[]authResult{
{pub, nil, "user@host", "", true},
})
}
func TestAuth(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithOptions := []string{
`# comments to ignore before any keys...`,
``,
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
`# comments to ignore, along with a blank line`,
``,
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
``,
`# more comments, plus a invalid entry`,
`ssh-rsa data-that-will-not-parse user@host3`,
}
for _, eol := range []string{"\n", "\r\n"} {
authOptions := strings.Join(authWithOptions, eol)
rest2 := strings.Join(authWithOptions[3:], eol)
rest3 := strings.Join(authWithOptions[6:], eol)
testAuthorizedKeys(t, []byte(authOptions), []authResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
})
}
}
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
})
}
func TestAuthWithInvalidSpace(t *testing.T) {
_, pubSerialized := getTestKey()
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
#more to follow but still no valid keys`)
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
{nil, nil, "", "", false},
})
}
func TestAuthWithMissingQuote(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
})
}
func TestInvalidEntry(t *testing.T) {
authInvalid := []byte(`ssh-rsa`)
_, _, _, _, err := ParseAuthorizedKey(authInvalid)
if err == nil {
t.Errorf("got valid entry for %q", authInvalid)
}
}

View file

@ -0,0 +1,53 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Message authentication support
import (
"crypto/hmac"
"crypto/sha1"
"hash"
)
type macMode struct {
keySize int
new func(key []byte) hash.Hash
}
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
// a given size.
type truncatingMAC struct {
length int
hmac hash.Hash
}
func (t truncatingMAC) Write(data []byte) (int, error) {
return t.hmac.Write(data)
}
func (t truncatingMAC) Sum(in []byte) []byte {
out := t.hmac.Sum(in)
return out[:len(in)+t.length]
}
func (t truncatingMAC) Reset() {
t.hmac.Reset()
}
func (t truncatingMAC) Size() int {
return t.length
}
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{
"hmac-sha1": {20, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key)
}},
"hmac-sha1-96": {20, func(key []byte) hash.Hash {
return truncatingMAC{12, hmac.New(sha1.New, key)}
}},
}

View file

@ -0,0 +1,110 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
"testing"
)
// An in-memory packetConn. It is safe to call Close and writePacket
// from different goroutines.
type memTransport struct {
eof bool
pending [][]byte
write *memTransport
sync.Mutex
*sync.Cond
}
func (t *memTransport) readPacket() ([]byte, error) {
t.Lock()
defer t.Unlock()
for {
if len(t.pending) > 0 {
r := t.pending[0]
t.pending = t.pending[1:]
return r, nil
}
if t.eof {
return nil, io.EOF
}
t.Cond.Wait()
}
}
func (t *memTransport) closeSelf() error {
t.Lock()
defer t.Unlock()
if t.eof {
return io.EOF
}
t.eof = true
t.Cond.Broadcast()
return nil
}
func (t *memTransport) Close() error {
err := t.write.closeSelf()
t.closeSelf()
return err
}
func (t *memTransport) writePacket(p []byte) error {
t.write.Lock()
defer t.write.Unlock()
if t.write.eof {
return io.EOF
}
c := make([]byte, len(p))
copy(c, p)
t.write.pending = append(t.write.pending, c)
t.write.Cond.Signal()
return nil
}
func memPipe() (a, b packetConn) {
t1 := memTransport{}
t2 := memTransport{}
t1.write = &t2
t2.write = &t1
t1.Cond = sync.NewCond(&t1.Mutex)
t2.Cond = sync.NewCond(&t2.Mutex)
return &t1, &t2
}
func TestmemPipe(t *testing.T) {
a, b := memPipe()
if err := a.writePacket([]byte{42}); err != nil {
t.Fatalf("writePacket: %v", err)
}
if err := a.Close(); err != nil {
t.Fatal("Close: ", err)
}
p, err := b.readPacket()
if err != nil {
t.Fatal("readPacket: ", err)
}
if len(p) != 1 || p[0] != 42 {
t.Fatalf("got %v, want {42}", p)
}
p, err = b.readPacket()
if err != io.EOF {
t.Fatalf("got %v, %v, want EOF", p, err)
}
}
func TestDoubleClose(t *testing.T) {
a, _ := memPipe()
err := a.Close()
if err != nil {
t.Errorf("Close: %v", err)
}
err = a.Close()
if err != io.EOF {
t.Errorf("expect EOF on double close.")
}
}

View file

@ -0,0 +1,724 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"reflect"
"strconv"
)
// These are SSH message type numbers. They are scattered around several
// documents but many were taken from [SSH-PARAMETERS].
const (
msgIgnore = 2
msgUnimplemented = 3
msgDebug = 4
msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
)
// SSH messages:
//
// These structures mirror the wire format of the corresponding SSH messages.
// They are marshaled using reflection with the marshal and unmarshal functions
// in this file. The only wrinkle is that a final member of type []byte with a
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
// See RFC 4253, section 11.1.
const msgDisconnect = 1
// disconnectMsg is the message that signals a disconnect. It is also
// the error type returned from mux.Wait()
type disconnectMsg struct {
Reason uint32 `sshtype:"1"`
Message string
Language string
}
func (d *disconnectMsg) Error() string {
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
}
// See RFC 4253, section 7.1.
const msgKexInit = 20
type kexInitMsg struct {
Cookie [16]byte `sshtype:"20"`
KexAlgos []string
ServerHostKeyAlgos []string
CiphersClientServer []string
CiphersServerClient []string
MACsClientServer []string
MACsServerClient []string
CompressionClientServer []string
CompressionServerClient []string
LanguagesClientServer []string
LanguagesServerClient []string
FirstKexFollows bool
Reserved uint32
}
// See RFC 4253, section 8.
// Diffie-Helman
const msgKexDHInit = 30
type kexDHInitMsg struct {
X *big.Int `sshtype:"30"`
}
const msgKexECDHInit = 30
type kexECDHInitMsg struct {
ClientPubKey []byte `sshtype:"30"`
}
const msgKexECDHReply = 31
type kexECDHReplyMsg struct {
HostKey []byte `sshtype:"31"`
EphemeralPubKey []byte
Signature []byte
}
const msgKexDHReply = 31
type kexDHReplyMsg struct {
HostKey []byte `sshtype:"31"`
Y *big.Int
Signature []byte
}
// See RFC 4253, section 10.
const msgServiceRequest = 5
type serviceRequestMsg struct {
Service string `sshtype:"5"`
}
// See RFC 4253, section 10.
const msgServiceAccept = 6
type serviceAcceptMsg struct {
Service string `sshtype:"6"`
}
// See RFC 4252, section 5.
const msgUserAuthRequest = 50
type userAuthRequestMsg struct {
User string `sshtype:"50"`
Service string
Method string
Payload []byte `ssh:"rest"`
}
// See RFC 4252, section 5.1
const msgUserAuthFailure = 51
type userAuthFailureMsg struct {
Methods []string `sshtype:"51"`
PartialSuccess bool
}
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61
type userAuthInfoRequestMsg struct {
User string `sshtype:"60"`
Instruction string
DeprecatedLanguage string
NumPrompts uint32
Prompts []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpen = 90
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
const msgChannelExtendedData = 95
const msgChannelData = 94
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"`
MyId uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
}
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
}
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
const msgGlobalRequest = 80
type globalRequestMsg struct {
Type string `sshtype:"80"`
WantReply bool
Data []byte `ssh:"rest"`
}
// See RFC 4254, section 4
const msgRequestSuccess = 81
type globalRequestSuccessMsg struct {
Data []byte `ssh:"rest" sshtype:"81"`
}
// See RFC 4254, section 4
const msgRequestFailure = 82
type globalRequestFailureMsg struct {
Data []byte `ssh:"rest" sshtype:"82"`
}
// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93
type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"`
AdditionalBytes uint32
}
// See RFC 4252, section 7
const msgUserAuthPubKeyOk = 60
type userAuthPubKeyOkMsg struct {
Algo string `sshtype:"60"`
PubKey []byte
}
// typeTag returns the type byte for the given type. The type should
// be struct.
func typeTag(structType reflect.Type) byte {
var tag byte
var tagStr string
tagStr = structType.Field(0).Tag.Get("sshtype")
i, err := strconv.Atoi(tagStr)
if err == nil {
tag = byte(i)
}
return tag
}
func fieldError(t reflect.Type, field int, problem string) error {
if problem != "" {
problem = ": " + problem
}
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
}
var errShortRead = errors.New("ssh: short read")
// Unmarshal parses data in SSH wire format into a structure. The out
// argument should be a pointer to struct. If the first member of the
// struct has the "sshtype" tag set to a number in decimal, the packet
// must start that number. In case of error, Unmarshal returns a
// ParseError or UnexpectedMessageError.
func Unmarshal(data []byte, out interface{}) error {
v := reflect.ValueOf(out).Elem()
structType := v.Type()
expectedType := typeTag(structType)
if len(data) == 0 {
return parseError(expectedType)
}
if expectedType > 0 {
if data[0] != expectedType {
return unexpectedMessageError(expectedType, data[0])
}
data = data[1:]
}
var ok bool
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
t := field.Type()
switch t.Kind() {
case reflect.Bool:
if len(data) < 1 {
return errShortRead
}
field.SetBool(data[0] != 0)
data = data[1:]
case reflect.Array:
if t.Elem().Kind() != reflect.Uint8 {
return fieldError(structType, i, "array of unsupported type")
}
if len(data) < t.Len() {
return errShortRead
}
for j, n := 0, t.Len(); j < n; j++ {
field.Index(j).Set(reflect.ValueOf(data[j]))
}
data = data[t.Len():]
case reflect.Uint64:
var u64 uint64
if u64, data, ok = parseUint64(data); !ok {
return errShortRead
}
field.SetUint(u64)
case reflect.Uint32:
var u32 uint32
if u32, data, ok = parseUint32(data); !ok {
return errShortRead
}
field.SetUint(uint64(u32))
case reflect.Uint8:
if len(data) < 1 {
return errShortRead
}
field.SetUint(uint64(data[0]))
data = data[1:]
case reflect.String:
var s []byte
if s, data, ok = parseString(data); !ok {
return fieldError(structType, i, "")
}
field.SetString(string(s))
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
if structType.Field(i).Tag.Get("ssh") == "rest" {
field.Set(reflect.ValueOf(data))
data = nil
} else {
var s []byte
if s, data, ok = parseString(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(s))
}
case reflect.String:
var nl []string
if nl, data, ok = parseNameList(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(nl))
default:
return fieldError(structType, i, "slice of unsupported type")
}
case reflect.Ptr:
if t == bigIntType {
var n *big.Int
if n, data, ok = parseInt(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(n))
} else {
return fieldError(structType, i, "pointer to unsupported type")
}
default:
return fieldError(structType, i, "unsupported type")
}
}
if len(data) != 0 {
return parseError(expectedType)
}
return nil
}
// Marshal serializes the message in msg to SSH wire format. The msg
// argument should be a struct or pointer to struct. If the first
// member has the "sshtype" tag set to a number in decimal, that
// number is prepended to the result. If the last of member has the
// "ssh" tag set to "rest", its contents are appended to the output.
func Marshal(msg interface{}) []byte {
out := make([]byte, 0, 64)
return marshalStruct(out, msg)
}
func marshalStruct(out []byte, msg interface{}) []byte {
v := reflect.Indirect(reflect.ValueOf(msg))
msgType := typeTag(v.Type())
if msgType > 0 {
out = append(out, msgType)
}
for i, n := 0, v.NumField(); i < n; i++ {
field := v.Field(i)
switch t := field.Type(); t.Kind() {
case reflect.Bool:
var v uint8
if field.Bool() {
v = 1
}
out = append(out, v)
case reflect.Array:
if t.Elem().Kind() != reflect.Uint8 {
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
}
for j, l := 0, t.Len(); j < l; j++ {
out = append(out, uint8(field.Index(j).Uint()))
}
case reflect.Uint32:
out = appendU32(out, uint32(field.Uint()))
case reflect.Uint64:
out = appendU64(out, uint64(field.Uint()))
case reflect.Uint8:
out = append(out, uint8(field.Uint()))
case reflect.String:
s := field.String()
out = appendInt(out, len(s))
out = append(out, s...)
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
if v.Type().Field(i).Tag.Get("ssh") != "rest" {
out = appendInt(out, field.Len())
}
out = append(out, field.Bytes()...)
case reflect.String:
offset := len(out)
out = appendU32(out, 0)
if n := field.Len(); n > 0 {
for j := 0; j < n; j++ {
f := field.Index(j)
if j != 0 {
out = append(out, ',')
}
out = append(out, f.String()...)
}
// overwrite length value
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
}
default:
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
}
case reflect.Ptr:
if t == bigIntType {
var n *big.Int
nValue := reflect.ValueOf(&n)
nValue.Elem().Set(field)
needed := intLength(n)
oldLength := len(out)
if cap(out)-len(out) < needed {
newOut := make([]byte, len(out), 2*(len(out)+needed))
copy(newOut, out)
out = newOut
}
out = out[:oldLength+needed]
marshalInt(out[oldLength:], n)
} else {
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
}
}
}
return out
}
var bigOne = big.NewInt(1)
func parseString(in []byte) (out, rest []byte, ok bool) {
if len(in) < 4 {
return
}
length := binary.BigEndian.Uint32(in)
if uint32(len(in)) < 4+length {
return
}
out = in[4 : 4+length]
rest = in[4+length:]
ok = true
return
}
var (
comma = []byte{','}
emptyNameList = []string{}
)
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
if !ok {
return
}
if len(contents) == 0 {
out = emptyNameList
return
}
parts := bytes.Split(contents, comma)
out = make([]string, len(parts))
for i, part := range parts {
out[i] = string(part)
}
return
}
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
if !ok {
return
}
out = new(big.Int)
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
// This is a negative number
notBytes := make([]byte, len(contents))
for i := range notBytes {
notBytes[i] = ^contents[i]
}
out.SetBytes(notBytes)
out.Add(out, bigOne)
out.Neg(out)
} else {
// Positive number
out.SetBytes(contents)
}
ok = true
return
}
func parseUint32(in []byte) (uint32, []byte, bool) {
if len(in) < 4 {
return 0, nil, false
}
return binary.BigEndian.Uint32(in), in[4:], true
}
func parseUint64(in []byte) (uint64, []byte, bool) {
if len(in) < 8 {
return 0, nil, false
}
return binary.BigEndian.Uint64(in), in[8:], true
}
func intLength(n *big.Int) int {
length := 4 /* length bytes */
if n.Sign() < 0 {
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bitLen := nMinus1.BitLen()
if bitLen%8 == 0 {
// The number will need 0xff padding
length++
}
length += (bitLen + 7) / 8
} else if n.Sign() == 0 {
// A zero is the zero length string
} else {
bitLen := n.BitLen()
if bitLen%8 == 0 {
// The number will need 0x00 padding
length++
}
length += (bitLen + 7) / 8
}
return length
}
func marshalUint32(to []byte, n uint32) []byte {
binary.BigEndian.PutUint32(to, n)
return to[4:]
}
func marshalUint64(to []byte, n uint64) []byte {
binary.BigEndian.PutUint64(to, n)
return to[8:]
}
func marshalInt(to []byte, n *big.Int) []byte {
lengthBytes := to
to = to[4:]
length := 0
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
to[0] = 0xff
to = to[1:]
length++
}
nBytes := copy(to, bytes)
to = to[nBytes:]
length += nBytes
} else if n.Sign() == 0 {
// A zero is the zero length string
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with a 0x00 in order to
// stop it looking like a negative number.
to[0] = 0
to = to[1:]
length++
}
nBytes := copy(to, bytes)
to = to[nBytes:]
length += nBytes
}
lengthBytes[0] = byte(length >> 24)
lengthBytes[1] = byte(length >> 16)
lengthBytes[2] = byte(length >> 8)
lengthBytes[3] = byte(length)
return to
}
func writeInt(w io.Writer, n *big.Int) {
length := intLength(n)
buf := make([]byte, length)
marshalInt(buf, n)
w.Write(buf)
}
func writeString(w io.Writer, s []byte) {
var lengthBytes [4]byte
lengthBytes[0] = byte(len(s) >> 24)
lengthBytes[1] = byte(len(s) >> 16)
lengthBytes[2] = byte(len(s) >> 8)
lengthBytes[3] = byte(len(s))
w.Write(lengthBytes[:])
w.Write(s)
}
func stringLength(n int) int {
return 4 + n
}
func marshalString(to []byte, s []byte) []byte {
to[0] = byte(len(s) >> 24)
to[1] = byte(len(s) >> 16)
to[2] = byte(len(s) >> 8)
to[3] = byte(len(s))
to = to[4:]
copy(to, s)
return to[len(s):]
}
var bigIntType = reflect.TypeOf((*big.Int)(nil))
// Decode a packet into its corresponding message.
func decode(packet []byte) (interface{}, error) {
var msg interface{}
switch packet[0] {
case msgDisconnect:
msg = new(disconnectMsg)
case msgServiceRequest:
msg = new(serviceRequestMsg)
case msgServiceAccept:
msg = new(serviceAcceptMsg)
case msgKexInit:
msg = new(kexInitMsg)
case msgKexDHInit:
msg = new(kexDHInitMsg)
case msgKexDHReply:
msg = new(kexDHReplyMsg)
case msgUserAuthRequest:
msg = new(userAuthRequestMsg)
case msgUserAuthFailure:
msg = new(userAuthFailureMsg)
case msgUserAuthPubKeyOk:
msg = new(userAuthPubKeyOkMsg)
case msgGlobalRequest:
msg = new(globalRequestMsg)
case msgRequestSuccess:
msg = new(globalRequestSuccessMsg)
case msgRequestFailure:
msg = new(globalRequestFailureMsg)
case msgChannelOpen:
msg = new(channelOpenMsg)
case msgChannelOpenConfirm:
msg = new(channelOpenConfirmMsg)
case msgChannelOpenFailure:
msg = new(channelOpenFailureMsg)
case msgChannelWindowAdjust:
msg = new(windowAdjustMsg)
case msgChannelEOF:
msg = new(channelEOFMsg)
case msgChannelClose:
msg = new(channelCloseMsg)
case msgChannelRequest:
msg = new(channelRequestMsg)
case msgChannelSuccess:
msg = new(channelRequestSuccessMsg)
case msgChannelFailure:
msg = new(channelRequestFailureMsg)
default:
return nil, unexpectedMessageError(0, packet[0])
}
if err := Unmarshal(packet, msg); err != nil {
return nil, err
}
return msg, nil
}

View file

@ -0,0 +1,244 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"math/big"
"math/rand"
"reflect"
"testing"
"testing/quick"
)
var intLengthTests = []struct {
val, length int
}{
{0, 4 + 0},
{1, 4 + 1},
{127, 4 + 1},
{128, 4 + 2},
{-1, 4 + 1},
}
func TestIntLength(t *testing.T) {
for _, test := range intLengthTests {
v := new(big.Int).SetInt64(int64(test.val))
length := intLength(v)
if length != test.length {
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
}
}
}
type msgAllTypes struct {
Bool bool `sshtype:"21"`
Array [16]byte
Uint64 uint64
Uint32 uint32
Uint8 uint8
String string
Strings []string
Bytes []byte
Int *big.Int
Rest []byte `ssh:"rest"`
}
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
m := &msgAllTypes{}
m.Bool = rand.Intn(2) == 1
randomBytes(m.Array[:], rand)
m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
m.Uint8 = uint8(rand.Intn(1 << 8))
m.String = string(m.Array[:])
m.Strings = randomNameList(rand)
m.Bytes = m.Array[:]
m.Int = randomInt(rand)
m.Rest = m.Array[:]
return reflect.ValueOf(m)
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
iface := &msgAllTypes{}
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("failed to create value")
break
}
m1 := v.Elem().Interface()
m2 := iface
marshaled := Marshal(m1)
if err := Unmarshal(marshaled, m2); err != nil {
t.Errorf("Unmarshal %#v: %s", m1, err)
break
}
if !reflect.DeepEqual(v.Interface(), m2) {
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
break
}
}
}
func TestUnmarshalEmptyPacket(t *testing.T) {
var b []byte
var m channelRequestSuccessMsg
if err := Unmarshal(b, &m); err == nil {
t.Fatalf("unmarshal of empty slice succeeded")
}
}
func TestUnmarshalUnexpectedPacket(t *testing.T) {
type S struct {
I uint32 `sshtype:"43"`
S string
B bool
}
s := S{11, "hello", true}
packet := Marshal(s)
packet[0] = 42
roundtrip := S{}
err := Unmarshal(packet, &roundtrip)
if err == nil {
t.Fatal("expected error, not nil")
}
}
func TestMarshalPtr(t *testing.T) {
s := struct {
S string
}{"hello"}
m1 := Marshal(s)
m2 := Marshal(&s)
if !bytes.Equal(m1, m2) {
t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
}
}
func TestBareMarshalUnmarshal(t *testing.T) {
type S struct {
I uint32
S string
B bool
}
s := S{42, "hello", true}
packet := Marshal(s)
roundtrip := S{}
Unmarshal(packet, &roundtrip)
if !reflect.DeepEqual(s, roundtrip) {
t.Errorf("got %#v, want %#v", roundtrip, s)
}
}
func TestBareMarshal(t *testing.T) {
type S2 struct {
I uint32
}
s := S2{42}
packet := Marshal(s)
i, rest, ok := parseUint32(packet)
if len(rest) > 0 || !ok {
t.Errorf("parseInt(%q): parse error", packet)
}
if i != s.I {
t.Errorf("got %d, want %d", i, s.I)
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())
}
}
func randomNameList(rand *rand.Rand) []string {
ret := make([]string, rand.Int31()&15)
for i := range ret {
s := make([]byte, 1+(rand.Int31()&15))
for j := range s {
s[j] = 'a' + uint8(rand.Int31()&15)
}
ret[i] = string(s)
}
return ret
}
func randomInt(rand *rand.Rand) *big.Int {
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
}
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
ki := &kexInitMsg{}
randomBytes(ki.Cookie[:], rand)
ki.KexAlgos = randomNameList(rand)
ki.ServerHostKeyAlgos = randomNameList(rand)
ki.CiphersClientServer = randomNameList(rand)
ki.CiphersServerClient = randomNameList(rand)
ki.MACsClientServer = randomNameList(rand)
ki.MACsServerClient = randomNameList(rand)
ki.CompressionClientServer = randomNameList(rand)
ki.CompressionServerClient = randomNameList(rand)
ki.LanguagesClientServer = randomNameList(rand)
ki.LanguagesServerClient = randomNameList(rand)
if rand.Int31()&1 == 1 {
ki.FirstKexFollows = true
}
return reflect.ValueOf(ki)
}
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
dhi := &kexDHInitMsg{}
dhi.X = randomInt(rand)
return reflect.ValueOf(dhi)
}
var (
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexInit = Marshal(_kexInitMsg)
_kexDHInit = Marshal(_kexDHInitMsg)
)
func BenchmarkMarshalKexInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexInitMsg)
}
}
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
m := new(kexInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexInit, m)
}
}
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
Marshal(_kexDHInitMsg)
}
}
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
m := new(kexDHInitMsg)
for i := 0; i < b.N; i++ {
Unmarshal(_kexDHInit, m)
}
}

View file

@ -0,0 +1,356 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/binary"
"fmt"
"io"
"log"
"sync"
"sync/atomic"
)
// debugMux, if set, causes messages in the connection protocol to be
// logged.
const debugMux = false
// chanList is a thread safe channel list.
type chanList struct {
// protects concurrent access to chans
sync.Mutex
// chans are indexed by the local id of the channel, which the
// other side should send in the PeersId field.
chans []*channel
// This is a debugging aid: it offsets all IDs by this
// amount. This helps distinguish otherwise identical
// server/client muxes
offset uint32
}
// Assigns a channel ID to the given channel.
func (c *chanList) add(ch *channel) uint32 {
c.Lock()
defer c.Unlock()
for i := range c.chans {
if c.chans[i] == nil {
c.chans[i] = ch
return uint32(i) + c.offset
}
}
c.chans = append(c.chans, ch)
return uint32(len(c.chans)-1) + c.offset
}
// getChan returns the channel for the given ID.
func (c *chanList) getChan(id uint32) *channel {
id -= c.offset
c.Lock()
defer c.Unlock()
if id < uint32(len(c.chans)) {
return c.chans[id]
}
return nil
}
func (c *chanList) remove(id uint32) {
id -= c.offset
c.Lock()
if id < uint32(len(c.chans)) {
c.chans[id] = nil
}
c.Unlock()
}
// dropAll forgets all channels it knows, returning them in a slice.
func (c *chanList) dropAll() []*channel {
c.Lock()
defer c.Unlock()
var r []*channel
for _, ch := range c.chans {
if ch == nil {
continue
}
r = append(r, ch)
}
c.chans = nil
return r
}
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
type mux struct {
conn packetConn
chanList chanList
incomingChannels chan NewChannel
globalSentMu sync.Mutex
globalResponses chan interface{}
incomingRequests chan *Request
errCond *sync.Cond
err error
}
// When debugging, each new chanList instantiation has a different
// offset.
var globalOff uint32
func (m *mux) Wait() error {
m.errCond.L.Lock()
defer m.errCond.L.Unlock()
for m.err == nil {
m.errCond.Wait()
}
return m.err
}
// newMux returns a mux that runs over the given connection.
func newMux(p packetConn) *mux {
m := &mux{
conn: p,
incomingChannels: make(chan NewChannel, 16),
globalResponses: make(chan interface{}, 1),
incomingRequests: make(chan *Request, 16),
errCond: newCond(),
}
if debugMux {
m.chanList.offset = atomic.AddUint32(&globalOff, 1)
}
go m.loop()
return m
}
func (m *mux) sendMessage(msg interface{}) error {
p := Marshal(msg)
return m.conn.writePacket(p)
}
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
if wantReply {
m.globalSentMu.Lock()
defer m.globalSentMu.Unlock()
}
if err := m.sendMessage(globalRequestMsg{
Type: name,
WantReply: wantReply,
Data: payload,
}); err != nil {
return false, nil, err
}
if !wantReply {
return false, nil, nil
}
msg, ok := <-m.globalResponses
if !ok {
return false, nil, io.EOF
}
switch msg := msg.(type) {
case *globalRequestFailureMsg:
return false, msg.Data, nil
case *globalRequestSuccessMsg:
return true, msg.Data, nil
default:
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
}
}
// ackRequest must be called after processing a global request that
// has WantReply set.
func (m *mux) ackRequest(ok bool, data []byte) error {
if ok {
return m.sendMessage(globalRequestSuccessMsg{Data: data})
}
return m.sendMessage(globalRequestFailureMsg{Data: data})
}
// TODO(hanwen): Disconnect is a transport layer message. We should
// probably send and receive Disconnect somewhere in the transport
// code.
// Disconnect sends a disconnect message.
func (m *mux) Disconnect(reason uint32, message string) error {
return m.sendMessage(disconnectMsg{
Reason: reason,
Message: message,
})
}
func (m *mux) Close() error {
return m.conn.Close()
}
// loop runs the connection machine. It will process packets until an
// error is encountered. To synchronize on loop exit, use mux.Wait.
func (m *mux) loop() {
var err error
for err == nil {
err = m.onePacket()
}
for _, ch := range m.chanList.dropAll() {
ch.close()
}
close(m.incomingChannels)
close(m.incomingRequests)
close(m.globalResponses)
m.conn.Close()
m.errCond.L.Lock()
m.err = err
m.errCond.Broadcast()
m.errCond.L.Unlock()
if debugMux {
log.Println("loop exit", err)
}
}
// onePacket reads and processes one packet.
func (m *mux) onePacket() error {
packet, err := m.conn.readPacket()
if err != nil {
return err
}
if debugMux {
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
} else {
p, _ := decode(packet)
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
}
}
switch packet[0] {
case msgNewKeys:
// Ignore notification of key change.
return nil
case msgDisconnect:
return m.handleDisconnect(packet)
case msgChannelOpen:
return m.handleChannelOpen(packet)
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
return m.handleGlobalPacket(packet)
}
// assume a channel packet.
if len(packet) < 5 {
return parseError(packet[0])
}
id := binary.BigEndian.Uint32(packet[1:])
ch := m.chanList.getChan(id)
if ch == nil {
return fmt.Errorf("ssh: invalid channel %d", id)
}
return ch.handlePacket(packet)
}
func (m *mux) handleDisconnect(packet []byte) error {
var d disconnectMsg
if err := Unmarshal(packet, &d); err != nil {
return err
}
if debugMux {
log.Printf("caught disconnect: %v", d)
}
return &d
}
func (m *mux) handleGlobalPacket(packet []byte) error {
msg, err := decode(packet)
if err != nil {
return err
}
switch msg := msg.(type) {
case *globalRequestMsg:
m.incomingRequests <- &Request{
Type: msg.Type,
WantReply: msg.WantReply,
Payload: msg.Data,
mux: m,
}
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
m.globalResponses <- msg
default:
panic(fmt.Sprintf("not a global message %#v", msg))
}
return nil
}
// handleChannelOpen schedules a channel to be Accept()ed.
func (m *mux) handleChannelOpen(packet []byte) error {
var msg channelOpenMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId,
Reason: ConnectionFailed,
Message: "invalid request",
Language: "en_US.UTF-8",
}
return m.sendMessage(failMsg)
}
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c
return nil
}
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
ch, err := m.openChannel(chanType, extra)
if err != nil {
return nil, nil, err
}
return ch, ch.incomingRequests, nil
}
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
ch := m.newChannel(chanType, channelOutbound, extra)
ch.maxIncomingPayload = channelMaxPacket
open := channelOpenMsg{
ChanType: chanType,
PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra,
PeersId: ch.localId,
}
if err := m.sendMessage(open); err != nil {
return nil, err
}
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
return ch, nil
case *channelOpenFailureMsg:
return nil, &OpenChannelError{msg.Reason, msg.Message}
default:
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
}
}

View file

@ -0,0 +1,525 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"io/ioutil"
"sync"
"testing"
)
func muxPair() (*mux, *mux) {
a, b := memPipe()
s := newMux(a)
c := newMux(b)
return s, c
}
// Returns both ends of a channel, and the mux for the the 2nd
// channel.
func channelPair(t *testing.T) (*channel, *channel, *mux) {
c, s := muxPair()
res := make(chan *channel, 1)
go func() {
newCh, ok := <-s.incomingChannels
if !ok {
t.Fatalf("No incoming channel")
}
if newCh.ChannelType() != "chan" {
t.Fatalf("got type %q want chan", newCh.ChannelType())
}
ch, _, err := newCh.Accept()
if err != nil {
t.Fatalf("Accept %v", err)
}
res <- ch.(*channel)
}()
ch, err := c.openChannel("chan", nil)
if err != nil {
t.Fatalf("OpenChannel: %v", err)
}
return <-res, ch, c
}
// Test that stderr and stdout can be addressed from different
// goroutines. This is intended for use with the race detector.
func TestMuxChannelExtendedThreadSafety(t *testing.T) {
writer, reader, mux := channelPair(t)
defer writer.Close()
defer reader.Close()
defer mux.Close()
var wr, rd sync.WaitGroup
magic := "hello world"
wr.Add(2)
go func() {
io.WriteString(writer, magic)
wr.Done()
}()
go func() {
io.WriteString(writer.Stderr(), magic)
wr.Done()
}()
rd.Add(2)
go func() {
c, err := ioutil.ReadAll(reader)
if string(c) != magic {
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
go func() {
c, err := ioutil.ReadAll(reader.Stderr())
if string(c) != magic {
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
wr.Wait()
writer.CloseWrite()
rd.Wait()
}
func TestMuxReadWrite(t *testing.T) {
s, c, mux := channelPair(t)
defer s.Close()
defer c.Close()
defer mux.Close()
magic := "hello world"
magicExt := "hello stderr"
go func() {
_, err := s.Write([]byte(magic))
if err != nil {
t.Fatalf("Write: %v", err)
}
_, err = s.Extended(1).Write([]byte(magicExt))
if err != nil {
t.Fatalf("Write: %v", err)
}
err = s.Close()
if err != nil {
t.Fatalf("Close: %v", err)
}
}()
var buf [1024]byte
n, err := c.Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got := string(buf[:n])
if got != magic {
t.Fatalf("server: got %q want %q", got, magic)
}
n, err = c.Extended(1).Read(buf[:])
if err != nil {
t.Fatalf("server Read: %v", err)
}
got = string(buf[:n])
if got != magicExt {
t.Fatalf("server: got %q want %q", got, magic)
}
}
func TestMuxChannelOverflow(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
// Send 1 byte.
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], writer.remoteId)
marshalUint32(packet[5:], uint32(1))
packet[9] = 42
if err := writer.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
<-wDone
}
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
reader.Close()
<-wDone
}
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
go func() {
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
mux.Close()
<-wDone
}
func TestMuxReject(t *testing.T) {
client, server := muxPair()
defer server.Close()
defer client.Close()
go func() {
ch, ok := <-server.incomingChannels
if !ok {
t.Fatalf("Accept")
}
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
}
ch.Reject(RejectionReason(42), "message")
}()
ch, err := client.openChannel("ch", []byte("extra"))
if ch != nil {
t.Fatal("openChannel not rejected")
}
ocf, ok := err.(*OpenChannelError)
if !ok {
t.Errorf("got %#v want *OpenChannelError", err)
} else if ocf.Reason != 42 || ocf.Message != "message" {
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
}
want := "ssh: rejected: unknown reason 42 (message)"
if err.Error() != want {
t.Errorf("got %q, want %q", err.Error(), want)
}
}
func TestMuxChannelRequest(t *testing.T) {
client, server, mux := channelPair(t)
defer server.Close()
defer client.Close()
defer mux.Close()
var received int
var wg sync.WaitGroup
wg.Add(1)
go func() {
for r := range server.incomingRequests {
received++
r.Reply(r.Type == "yes", nil)
}
wg.Done()
}()
_, err := client.SendRequest("yes", false, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
ok, err := client.SendRequest("yes", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if !ok {
t.Errorf("SendRequest(yes): %v", ok)
}
ok, err = client.SendRequest("no", true, nil)
if err != nil {
t.Fatalf("SendRequest: %v", err)
}
if ok {
t.Errorf("SendRequest(no): %v", ok)
}
client.Close()
wg.Wait()
if received != 3 {
t.Errorf("got %d requests, want %d", received, 3)
}
}
func TestMuxGlobalRequest(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
var seen bool
go func() {
for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
if err != nil {
t.Errorf("AckRequest: %v", err)
}
}
}
}()
_, _, err := clientMux.SendRequest("peek", false, nil)
if err != nil {
t.Errorf("SendRequest: %v", err)
}
ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
if !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
ok, data, err)
}
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}
clientMux.Disconnect(0, "")
if !seen {
t.Errorf("never saw 'peek' request")
}
}
func TestMuxGlobalRequestUnblock(t *testing.T) {
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
result := make(chan error, 1)
go func() {
_, _, err := clientMux.SendRequest("hello", true, nil)
result <- err
}()
<-serverMux.incomingRequests
serverMux.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", io.EOF)
}
}
func TestMuxChannelRequestUnblock(t *testing.T) {
a, b, connB := channelPair(t)
defer a.Close()
defer b.Close()
defer connB.Close()
result := make(chan error, 1)
go func() {
_, err := a.SendRequest("hello", true, nil)
result <- err
}()
<-b.incomingRequests
connB.conn.Close()
err := <-result
if err != io.EOF {
t.Errorf("want EOF, got %v", err)
}
}
func TestMuxDisconnect(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()
go func() {
for r := range b.incomingRequests {
r.Reply(true, nil)
}
}()
a.Disconnect(42, "whatever")
ok, _, err := a.SendRequest("hello", true, nil)
if ok || err == nil {
t.Errorf("got reply after disconnecting")
}
err = b.Wait()
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
}
}
func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
defer r.Close()
defer w.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.Close(); err != nil {
t.Errorf("w.Close: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after Close", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxCloseWriteChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
result := make(chan error, 1)
go func() {
var b [1024]byte
_, err := r.Read(b[:])
result <- err
}()
if err := w.CloseWrite(); err != nil {
t.Errorf("w.CloseWrite: %v", err)
}
if _, err := w.Write([]byte("hello")); err != io.EOF {
t.Errorf("got err %v, want io.EOF after CloseWrite", err)
}
if err := <-result; err != io.EOF {
t.Errorf("got %v (%T), want io.EOF", err, err)
}
}
func TestMuxInvalidRecord(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
marshalUint32(packet[5:], 1)
packet[9] = 42
a.conn.writePacket(packet)
go a.SendRequest("hello", false, nil)
// 'a' wrote an invalid packet, so 'b' has exited.
req, ok := <-b.incomingRequests
if ok {
t.Errorf("got request %#v after receiving invalid packet", req)
}
}
func TestZeroWindowAdjust(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
go func() {
io.WriteString(a, "hello")
// bogus adjust.
a.sendMessage(windowAdjustMsg{})
io.WriteString(a, "world")
a.Close()
}()
want := "helloworld"
c, _ := ioutil.ReadAll(b)
if string(c) != want {
t.Errorf("got %q want %q", c, want)
}
}
func TestMuxMaxPacketSize(t *testing.T) {
a, b, mux := channelPair(t)
defer a.Close()
defer b.Close()
defer mux.Close()
large := make([]byte, a.maxRemotePayload+1)
packet := make([]byte, 1+4+4+1+len(large))
packet[0] = msgChannelData
marshalUint32(packet[1:], a.remoteId)
marshalUint32(packet[5:], uint32(len(large)))
packet[9] = 42
if err := a.mux.conn.writePacket(packet); err != nil {
t.Errorf("could not send packet")
}
go a.SendRequest("hello", false, nil)
_, ok := <-b.incomingRequests
if ok {
t.Errorf("connection still alive after receiving large packet.")
}
}
// Don't ship code with debug=true.
func TestDebug(t *testing.T) {
if debugMux {
t.Error("mux debug switched on")
}
if debugHandshake {
t.Error("handshake debug switched on")
}
}

View file

@ -0,0 +1,477 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"net"
)
// The Permissions type holds fine-grained permissions that are
// specific to a user or a specific authentication method for a
// user. Permissions, except for "source-address", must be enforced in
// the server application layer, after successful authentication. The
// Permissions are passed on in ServerConn so a server implementation
// can honor them.
type Permissions struct {
// Critical options restrict default permissions. Common
// restrictions are "source-address" and "force-command". If
// the server cannot enforce the restriction, or does not
// recognize it, the user should not authenticate.
CriticalOptions map[string]string
// Extensions are extra functionality that the server may
// offer on authenticated connections. Common extensions are
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
// support for an extension does not preclude authenticating a
// user.
Extensions map[string]string
}
// ServerConfig holds server specific configuration data.
type ServerConfig struct {
// Config contains configuration shared between client and server.
Config
hostKeys []Signer
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
NoClientAuth bool
// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback, if non-nil, is called when a client attempts public
// key authentication. It must return true if the given public key is
// valid for the given user. For example, see CertChecker.Authenticate.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
// 4256). The client object's Challenge function should be
// used to query the user. The callback may offer multiple
// Challenge rounds. To avoid information leaks, the client
// should be presented a challenge even if the user is
// unknown.
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
// AuthLogCallback, if non-nil, is called to log all authentication
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)
}
// AddHostKey adds a private key as a host key. If an existing host
// key exists with the same algorithm, it is overwritten. Each server
// config must have at least one host key.
func (s *ServerConfig) AddHostKey(key Signer) {
for i, k := range s.hostKeys {
if k.PublicKey().Type() == key.PublicKey().Type() {
s.hostKeys[i] = key
return
}
}
s.hostKeys = append(s.hostKeys, key)
}
// cachedPubKey contains the results of querying whether a public key is
// acceptable for a user.
type cachedPubKey struct {
user string
pubKeyData []byte
result error
perms *Permissions
}
const maxCachedPubKeys = 16
// pubKeyCache caches tests for public keys. Since SSH clients
// will query whether a public key is acceptable before attempting to
// authenticate with it, we end up with duplicate queries for public
// key validity. The cache only applies to a single ServerConn.
type pubKeyCache struct {
keys []cachedPubKey
}
// get returns the result for a given user/algo/key tuple.
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
for _, k := range c.keys {
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
return k, true
}
}
return cachedPubKey{}, false
}
// add adds the given tuple to the cache.
func (c *pubKeyCache) add(candidate cachedPubKey) {
if len(c.keys) < maxCachedPubKeys {
c.keys = append(c.keys, candidate)
}
}
// ServerConn is an authenticated SSH connection, as seen from the
// server
type ServerConn struct {
Conn
// If the succeeding authentication callback returned a
// non-nil Permissions pointer, it is stored here.
Permissions *Permissions
}
// NewServerConn starts a new SSH server with c as the underlying
// transport. It starts with a handshake and, if the handshake is
// unsuccessful, it closes the connection and returns an error. The
// Request and NewChannel channels must be serviced, or the connection
// will hang.
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
s := &connection{
sshConn: sshConn{conn: c},
}
perms, err := s.serverHandshake(&fullConf)
if err != nil {
c.Close()
return nil, nil, nil, err
}
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
}
// signAndMarshal signs the data with the appropriate algorithm,
// and serializes the result in SSH wire format.
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
sig, err := k.Sign(rand, data)
if err != nil {
return nil, err
}
return Marshal(sig), nil
}
// handshake performs key exchange and user authentication.
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
if len(config.hostKeys) == 0 {
return nil, errors.New("ssh: server has no host keys")
}
var err error
s.serverVersion = []byte(packageVersion)
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
if err != nil {
return nil, err
}
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
if err := s.transport.requestKeyChange(); err != nil {
return nil, err
}
if packet, err := s.transport.readPacket(); err != nil {
return nil, err
} else if packet[0] != msgNewKeys {
return nil, unexpectedMessageError(msgNewKeys, packet[0])
}
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
return nil, err
}
var serviceRequest serviceRequestMsg
if err = Unmarshal(packet, &serviceRequest); err != nil {
return nil, err
}
if serviceRequest.Service != serviceUserAuth {
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
}
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
return nil, err
}
perms, err := s.serverAuthenticate(config)
if err != nil {
return nil, err
}
s.mux = newMux(s.transport)
return perms, err
}
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
return true
}
return false
}
func checkSourceAddress(addr net.Addr, sourceAddr string) error {
if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required")
}
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
}
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
if bytes.Equal(allowedIP, tcpAddr.IP) {
return nil
}
} else {
_, ipNet, err := net.ParseCIDR(sourceAddr)
if err != nil {
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
}
if ipNet.Contains(tcpAddr.IP) {
return nil
}
}
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
var err error
var cache pubKeyCache
var perms *Permissions
userAuthLoop:
for {
var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil {
return nil, err
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
return nil, err
}
if userAuthReq.Service != serviceSSH {
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
s.user = userAuthReq.User
perms = nil
authErr := errors.New("no auth passed yet")
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
s.user = ""
authErr = nil
}
case "password":
if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
break
}
payload := userAuthReq.Payload
if len(payload) < 1 || payload[0] != 0 {
return nil, parseError(msgUserAuthRequest)
}
payload = payload[1:]
password, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
perms, authErr = config.PasswordCallback(s, password)
case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configubred")
break
}
prompter := &sshClientKeyboardInteractive{s}
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey":
if config.PublicKeyCallback == nil {
authErr = errors.New("ssh: publickey auth not configured")
break
}
payload := userAuthReq.Payload
if len(payload) < 1 {
return nil, parseError(msgUserAuthRequest)
}
isQuery := payload[0] == 0
payload = payload[1:]
algoBytes, payload, ok := parseString(payload)
if !ok {
return nil, parseError(msgUserAuthRequest)
}
algo := string(algoBytes)
if !isAcceptableAlgo(algo) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
break
}
pubKeyData, payload, ok := parseString(payload)
if !ok {
return nil, parseError(msgUserAuthRequest)
}
pubKey, err := ParsePublicKey(pubKeyData)
if err != nil {
return nil, err
}
candidate, ok := cache.get(s.user, pubKeyData)
if !ok {
candidate.user = s.user
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
candidate.result = checkSourceAddress(
s.RemoteAddr(),
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
}
cache.add(candidate)
}
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
if candidate.result == nil {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: pubKeyData,
}
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
return nil, err
}
continue userAuthLoop
}
authErr = candidate.result
} else {
sig, payload, ok := parseSignature(payload)
if !ok || len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
// Ensure the public key algo and signature algo
// are supported. Compare the private key
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !isAcceptableAlgo(sig.Format) {
break
}
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
if err := pubKey.Verify(signedData, sig); err != nil {
return nil, err
}
authErr = candidate.result
perms = candidate.perms
}
default:
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
}
if config.AuthLogCallback != nil {
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}
if authErr == nil {
break userAuthLoop
}
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
if config.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
if config.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
}
if len(failureMsg.Methods) == 0 {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
return nil, err
}
}
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
return nil, err
}
return perms, nil
}
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
// asking the client on the other side of a ServerConn.
type sshClientKeyboardInteractive struct {
*connection
}
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
if len(questions) != len(echos) {
return nil, errors.New("ssh: echos and questions must have equal length")
}
var prompts []byte
for i := range questions {
prompts = appendString(prompts, questions[i])
prompts = appendBool(prompts, echos[i])
}
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
})); err != nil {
return nil, err
}
packet, err := c.transport.readPacket()
if err != nil {
return nil, err
}
if packet[0] != msgUserAuthInfoResponse {
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
}
packet = packet[1:]
n, packet, ok := parseUint32(packet)
if !ok || int(n) != len(questions) {
return nil, parseError(msgUserAuthInfoResponse)
}
for i := uint32(0); i < n; i++ {
ans, rest, ok := parseString(packet)
if !ok {
return nil, parseError(msgUserAuthInfoResponse)
}
answers = append(answers, string(ans))
packet = rest
}
if len(packet) != 0 {
return nil, errors.New("ssh: junk at end of message")
}
return answers, nil
}

View file

@ -0,0 +1,605 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session implements an interactive session described in
// "RFC 4254, section 6".
import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
)
type Signal string
// POSIX signals as listed in RFC 4254 Section 6.10.
const (
SIGABRT Signal = "ABRT"
SIGALRM Signal = "ALRM"
SIGFPE Signal = "FPE"
SIGHUP Signal = "HUP"
SIGILL Signal = "ILL"
SIGINT Signal = "INT"
SIGKILL Signal = "KILL"
SIGPIPE Signal = "PIPE"
SIGQUIT Signal = "QUIT"
SIGSEGV Signal = "SEGV"
SIGTERM Signal = "TERM"
SIGUSR1 Signal = "USR1"
SIGUSR2 Signal = "USR2"
)
var signals = map[Signal]int{
SIGABRT: 6,
SIGALRM: 14,
SIGFPE: 8,
SIGHUP: 1,
SIGILL: 4,
SIGINT: 2,
SIGKILL: 9,
SIGPIPE: 13,
SIGQUIT: 3,
SIGSEGV: 11,
SIGTERM: 15,
}
type TerminalModes map[uint8]uint32
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
const (
tty_OP_END = 0
VINTR = 1
VQUIT = 2
VERASE = 3
VKILL = 4
VEOF = 5
VEOL = 6
VEOL2 = 7
VSTART = 8
VSTOP = 9
VSUSP = 10
VDSUSP = 11
VREPRINT = 12
VWERASE = 13
VLNEXT = 14
VFLUSH = 15
VSWTCH = 16
VSTATUS = 17
VDISCARD = 18
IGNPAR = 30
PARMRK = 31
INPCK = 32
ISTRIP = 33
INLCR = 34
IGNCR = 35
ICRNL = 36
IUCLC = 37
IXON = 38
IXANY = 39
IXOFF = 40
IMAXBEL = 41
ISIG = 50
ICANON = 51
XCASE = 52
ECHO = 53
ECHOE = 54
ECHOK = 55
ECHONL = 56
NOFLSH = 57
TOSTOP = 58
IEXTEN = 59
ECHOCTL = 60
ECHOKE = 61
PENDIN = 62
OPOST = 70
OLCUC = 71
ONLCR = 72
OCRNL = 73
ONOCR = 74
ONLRET = 75
CS7 = 90
CS8 = 91
PARENB = 92
PARODD = 93
TTY_OP_ISPEED = 128
TTY_OP_OSPEED = 129
)
// A Session represents a connection to a remote command or shell.
type Session struct {
// Stdin specifies the remote process's standard input.
// If Stdin is nil, the remote process reads from an empty
// bytes.Buffer.
Stdin io.Reader
// Stdout and Stderr specify the remote process's standard
// output and error.
//
// If either is nil, Run connects the corresponding file
// descriptor to an instance of ioutil.Discard. There is a
// fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote
// command to block.
Stdout io.Writer
Stderr io.Writer
ch Channel // the channel backing this session
started bool // true once Start, Run or Shell is invoked.
copyFuncs []func() error
errors chan error // one send per copyFunc
// true if pipe method is active
stdinpipe, stdoutpipe, stderrpipe bool
// stdinPipeWriter is non-nil if StdinPipe has not been called
// and Stdin was specified by the user; it is the write end of
// a pipe connecting Session.Stdin to the stdin channel.
stdinPipeWriter io.WriteCloser
exitStatus chan error
}
// SendRequest sends an out-of-band channel request on the SSH channel
// underlying the session.
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return s.ch.SendRequest(name, wantReply, payload)
}
func (s *Session) Close() error {
return s.ch.Close()
}
// RFC 4254 Section 6.4.
type setenvRequest struct {
Name string
Value string
}
// Setenv sets an environment variable that will be applied to any
// command executed by Shell or Run.
func (s *Session) Setenv(name, value string) error {
msg := setenvRequest{
Name: name,
Value: value,
}
ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
if err == nil && !ok {
err = errors.New("ssh: setenv failed")
}
return err
}
// RFC 4254 Section 6.2.
type ptyRequestMsg struct {
Term string
Columns uint32
Rows uint32
Width uint32
Height uint32
Modelist string
}
// RequestPty requests the association of a pty with the session on the remote host.
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
var tm []byte
for k, v := range termmodes {
kv := struct {
Key byte
Val uint32
}{k, v}
tm = append(tm, Marshal(&kv)...)
}
tm = append(tm, tty_OP_END)
req := ptyRequestMsg{
Term: term,
Columns: uint32(w),
Rows: uint32(h),
Width: uint32(w * 8),
Height: uint32(h * 8),
Modelist: string(tm),
}
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
if err == nil && !ok {
err = errors.New("ssh: pty-req failed")
}
return err
}
// RFC 4254 Section 6.5.
type subsystemRequestMsg struct {
Subsystem string
}
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
func (s *Session) RequestSubsystem(subsystem string) error {
msg := subsystemRequestMsg{
Subsystem: subsystem,
}
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
if err == nil && !ok {
err = errors.New("ssh: subsystem request failed")
}
return err
}
// RFC 4254 Section 6.9.
type signalMsg struct {
Signal string
}
// Signal sends the given signal to the remote process.
// sig is one of the SIG* constants.
func (s *Session) Signal(sig Signal) error {
msg := signalMsg{
Signal: string(sig),
}
_, err := s.ch.SendRequest("signal", false, Marshal(&msg))
return err
}
// RFC 4254 Section 6.5.
type execMsg struct {
Command string
}
// Start runs cmd on the remote host. Typically, the remote
// server passes cmd to the shell for interpretation.
// A Session only accepts one call to Run, Start or Shell.
func (s *Session) Start(cmd string) error {
if s.started {
return errors.New("ssh: session already started")
}
req := execMsg{
Command: cmd,
}
ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
if err == nil && !ok {
err = fmt.Errorf("ssh: command %v failed", cmd)
}
if err != nil {
return err
}
return s.start()
}
// Run runs cmd on the remote host. Typically, the remote
// server passes cmd to the shell for interpretation.
// A Session only accepts one call to Run, Start, Shell, Output,
// or CombinedOutput.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
func (s *Session) Run(cmd string) error {
err := s.Start(cmd)
if err != nil {
return err
}
return s.Wait()
}
// Output runs cmd on the remote host and returns its standard output.
func (s *Session) Output(cmd string) ([]byte, error) {
if s.Stdout != nil {
return nil, errors.New("ssh: Stdout already set")
}
var b bytes.Buffer
s.Stdout = &b
err := s.Run(cmd)
return b.Bytes(), err
}
type singleWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *singleWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
// CombinedOutput runs cmd on the remote host and returns its combined
// standard output and standard error.
func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
if s.Stdout != nil {
return nil, errors.New("ssh: Stdout already set")
}
if s.Stderr != nil {
return nil, errors.New("ssh: Stderr already set")
}
var b singleWriter
s.Stdout = &b
s.Stderr = &b
err := s.Run(cmd)
return b.b.Bytes(), err
}
// Shell starts a login shell on the remote host. A Session only
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
func (s *Session) Shell() error {
if s.started {
return errors.New("ssh: session already started")
}
ok, err := s.ch.SendRequest("shell", true, nil)
if err == nil && !ok {
return fmt.Errorf("ssh: cound not start shell")
}
if err != nil {
return err
}
return s.start()
}
func (s *Session) start() error {
s.started = true
type F func(*Session)
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
setupFd(s)
}
s.errors = make(chan error, len(s.copyFuncs))
for _, fn := range s.copyFuncs {
go func(fn func() error) {
s.errors <- fn()
}(fn)
}
return nil
}
// Wait waits for the remote command to exit.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
func (s *Session) Wait() error {
if !s.started {
return errors.New("ssh: session not started")
}
waitErr := <-s.exitStatus
if s.stdinPipeWriter != nil {
s.stdinPipeWriter.Close()
}
var copyError error
for _ = range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil {
copyError = err
}
}
if waitErr != nil {
return waitErr
}
return copyError
}
func (s *Session) wait(reqs <-chan *Request) error {
wm := Waitmsg{status: -1}
// Wait for msg channel to be closed before returning.
for msg := range reqs {
switch msg.Type {
case "exit-status":
d := msg.Payload
wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
case "exit-signal":
var sigval struct {
Signal string
CoreDumped bool
Error string
Lang string
}
if err := Unmarshal(msg.Payload, &sigval); err != nil {
return err
}
// Must sanitize strings?
wm.signal = sigval.Signal
wm.msg = sigval.Error
wm.lang = sigval.Lang
default:
// This handles keepalives and matches
// OpenSSH's behaviour.
if msg.WantReply {
msg.Reply(false, nil)
}
}
}
if wm.status == 0 {
return nil
}
if wm.status == -1 {
// exit-status was never sent from server
if wm.signal == "" {
return errors.New("wait: remote command exited without exit status or exit signal")
}
wm.status = 128
if _, ok := signals[Signal(wm.signal)]; ok {
wm.status += signals[Signal(wm.signal)]
}
}
return &ExitError{wm}
}
func (s *Session) stdin() {
if s.stdinpipe {
return
}
var stdin io.Reader
if s.Stdin == nil {
stdin = new(bytes.Buffer)
} else {
r, w := io.Pipe()
go func() {
_, err := io.Copy(w, s.Stdin)
w.CloseWithError(err)
}()
stdin, s.stdinPipeWriter = r, w
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.ch, stdin)
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
err = err1
}
return err
})
}
func (s *Session) stdout() {
if s.stdoutpipe {
return
}
if s.Stdout == nil {
s.Stdout = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, s.ch)
return err
})
}
func (s *Session) stderr() {
if s.stderrpipe {
return
}
if s.Stderr == nil {
s.Stderr = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, s.ch.Stderr())
return err
})
}
// sessionStdin reroutes Close to CloseWrite.
type sessionStdin struct {
io.Writer
ch Channel
}
func (s *sessionStdin) Close() error {
return s.ch.CloseWrite()
}
// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
if s.Stdin != nil {
return nil, errors.New("ssh: Stdin already set")
}
if s.started {
return nil, errors.New("ssh: StdinPipe after process started")
}
s.stdinpipe = true
return &sessionStdin{s.ch, s.ch}, nil
}
// StdoutPipe returns a pipe that will be connected to the
// remote command's standard output when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StdoutPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StdoutPipe() (io.Reader, error) {
if s.Stdout != nil {
return nil, errors.New("ssh: Stdout already set")
}
if s.started {
return nil, errors.New("ssh: StdoutPipe after process started")
}
s.stdoutpipe = true
return s.ch, nil
}
// StderrPipe returns a pipe that will be connected to the
// remote command's standard error when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StderrPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StderrPipe() (io.Reader, error) {
if s.Stderr != nil {
return nil, errors.New("ssh: Stderr already set")
}
if s.started {
return nil, errors.New("ssh: StderrPipe after process started")
}
s.stderrpipe = true
return s.ch.Stderr(), nil
}
// newSession returns a new interactive session on the remote host.
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
s := &Session{
ch: ch,
}
s.exitStatus = make(chan error, 1)
go func() {
s.exitStatus <- s.wait(reqs)
}()
return s, nil
}
// An ExitError reports unsuccessful completion of a remote command.
type ExitError struct {
Waitmsg
}
func (e *ExitError) Error() string {
return e.Waitmsg.String()
}
// Waitmsg stores the information about an exited remote command
// as reported by Wait.
type Waitmsg struct {
status int
signal string
msg string
lang string
}
// ExitStatus returns the exit status of the remote command.
func (w Waitmsg) ExitStatus() int {
return w.status
}
// Signal returns the exit signal of the remote command if
// it was terminated violently.
func (w Waitmsg) Signal() string {
return w.signal
}
// Msg returns the exit message given by the remote command
func (w Waitmsg) Msg() string {
return w.msg
}
// Lang returns the language tag. See RFC 3066
func (w Waitmsg) Lang() string {
return w.lang
}
func (w Waitmsg) String() string {
return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal)
}

View file

@ -0,0 +1,628 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
crypto_rand "crypto/rand"
"io"
"io/ioutil"
"math/rand"
"testing"
"golang.org/x/crypto/ssh/terminal"
)
type serverType func(Channel, <-chan *Request, *testing.T)
// dial constructs a new test server and returns a *ClientConn.
func dial(handler serverType, t *testing.T) *Client {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
go func() {
defer c1.Close()
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(testSigners["rsa"])
_, chans, reqs, err := NewServerConn(c1, &conf)
if err != nil {
t.Fatalf("Unable to handshake: %v", err)
}
go DiscardRequests(reqs)
for newCh := range chans {
if newCh.ChannelType() != "session" {
newCh.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch, inReqs, err := newCh.Accept()
if err != nil {
t.Errorf("Accept: %v", err)
continue
}
go func() {
handler(ch, inReqs, t)
}()
}
}()
config := &ClientConfig{
User: "testuser",
}
conn, chans, reqs, err := NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("unable to dial remote side: %v", err)
}
return NewClient(conn, chans, reqs)
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(shellHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %v", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// Test that a simple string is returned via the Output helper,
// and that stderr is discarded.
func TestSessionOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
w := "this-is-stdout."
g := string(buf)
if g != w {
t.Error("Remote command did not return expected string:")
t.Logf("want %q", w)
t.Logf("got %q", g)
}
}
// Test that both stdout and stderr are returned
// via the CombinedOutput helper.
func TestSessionCombinedOutput(t *testing.T) {
conn := dial(fixedOutputHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
if err != nil {
t.Error("Remote command did not exit cleanly:", err)
}
const stdout = "this-is-stdout."
const stderr = "this-is-stderr."
g := string(buf)
if g != stdout+stderr && g != stderr+stdout {
t.Error("Remote command did not return expected string:")
t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
t.Logf("got %q", g)
}
}
// Test non-0 exit status is returned correctly.
func TestExitStatusNonZero(t *testing.T) {
conn := dial(exitStatusNonZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
}
}
// Test 0 exit status is returned correctly.
func TestExitStatusZero(t *testing.T) {
conn := dial(exitStatusZeroHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
}
// Test exit signal and status are both returned correctly.
func TestExitSignalAndStatus(t *testing.T) {
conn := dial(exitSignalAndStatusHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 15 {
t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestKnownExitSignalOnly(t *testing.T) {
conn := dial(exitSignalHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "TERM" || e.ExitStatus() != 143 {
t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test exit signal and status are both returned correctly.
func TestUnknownExitSignal(t *testing.T) {
conn := dial(exitSignalUnknownHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
e, ok := err.(*ExitError)
if !ok {
t.Fatalf("expected *ExitError but got %T", err)
}
if e.Signal() != "SYS" || e.ExitStatus() != 128 {
t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
}
}
// Test WaitMsg is not returned if the channel closes abruptly.
func TestExitWithoutStatusOrSignal(t *testing.T) {
conn := dial(exitWithoutSignalOrStatus, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %v", err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
_, ok := err.(*ExitError)
if ok {
// you can't actually test for errors.errorString
// because it's not exported.
t.Fatalf("expected *errorString but got %T", err)
}
}
// windowTestBytes is the number of bytes that we'll send to the SSH server.
const windowTestBytes = 16000 * 200
// TestServerWindow writes random data to the server. The server is expected to echo
// the same data back, which is compared against the original.
func TestServerWindow(t *testing.T) {
origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
origBytes := origBuf.Bytes()
conn := dial(echoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
serverStdout, err := session.StdoutPipe()
if err != nil {
t.Errorf("StdoutPipe failed: %v", err)
return
}
n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
if err != nil && err != io.EOF {
t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
}
result <- echoedBuf.Bytes()
}()
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Fatalf("failed to copy origBuf to serverStdin: %v", err)
}
if written != windowTestBytes {
t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
}
echoedBytes := <-result
if !bytes.Equal(origBytes, echoedBytes) {
t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
}
}
// Verify the client can handle a keepalive packet from the server.
func TestClientHandlesKeepalives(t *testing.T) {
conn := dial(channelKeepaliveSender, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %v", err)
}
err = session.Wait()
if err != nil {
t.Fatalf("expected nil but got: %v", err)
}
}
type exitStatusMsg struct {
Status uint32
}
type exitSignalMsg struct {
Signal string
CoreDumped bool
Errmsg string
Lang string
}
func handleTerminalRequests(in <-chan *Request) {
for req := range in {
ok := false
switch req.Type {
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
req.Reply(ok, nil)
}
}
func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
term := terminal.NewTerminal(ch, prompt)
go handleTerminalRequests(in)
return term
}
func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(0, ch, t)
}
func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
}
func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
sendSignal("TERM", ch, t)
}
func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("TERM", ch, t)
}
func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("SYS", ch, t)
}
func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
}
func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
shell := newServerShell(ch, in, "golang")
readLine(shell, t)
sendStatus(0, ch, t)
}
// Ignores the command, writes fixed strings to stderr and stdout.
// Strings are "this-is-stdout." and "this-is-stderr.".
func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
_, err := ch.Read(nil)
req, ok := <-in
if !ok {
t.Fatalf("error: expected channel request, got: %#v", err)
return
}
// ignore request, always send some text
req.Reply(true, nil)
_, err = io.WriteString(ch, "this-is-stdout.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
if err != nil {
t.Fatalf("error writing on server: %v", err)
}
sendStatus(0, ch, t)
}
func readLine(shell *terminal.Terminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
t.Errorf("unable to read line: %v", err)
}
}
func sendStatus(status uint32, ch Channel, t *testing.T) {
msg := exitStatusMsg{
Status: status,
}
if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
t.Errorf("unable to send status: %v", err)
}
}
func sendSignal(signal string, ch Channel, t *testing.T) {
sig := exitSignalMsg{
Signal: signal,
CoreDumped: false,
Errmsg: "Process terminated",
Lang: "en-GB-oed",
}
if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
t.Errorf("unable to send signal: %v", err)
}
}
func discardHandler(ch Channel, t *testing.T) {
defer ch.Close()
io.Copy(ioutil.Discard, ch)
}
func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
}
}
// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
// buffer size to exercise more code paths.
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
var (
buf = make([]byte, 32*1024)
written int
remaining = n
)
for remaining > 0 {
l := rand.Intn(1 << 15)
if remaining < l {
l = remaining
}
nr, er := src.Read(buf[:l])
nw, ew := dst.Write(buf[:nr])
remaining -= nw
written += nw
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
if er != nil && er != io.EOF {
return written, er
}
}
return written, nil
}
func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
shell := newServerShell(ch, in, "> ")
readLine(shell, t)
if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
t.Errorf("unable to send channel keepalive request: %v", err)
}
sendStatus(0, ch, t)
}
func TestClientWriteEOF(t *testing.T) {
conn := dial(simpleEchoHandler, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("StdoutPipe failed: %v", err)
}
data := []byte(`0000`)
_, err = stdin.Write(data)
if err != nil {
t.Fatalf("Write failed: %v", err)
}
stdin.Close()
res, err := ioutil.ReadAll(stdout)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if !bytes.Equal(data, res) {
t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
}
}
func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
data, err := ioutil.ReadAll(ch)
if err != nil {
t.Errorf("handler read error: %v", err)
}
_, err = ch.Write(data)
if err != nil {
t.Errorf("handler write error: %v", err)
}
}

View file

@ -0,0 +1,404 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
)
// Listen requests the remote peer open a listening socket on
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
}
// Automatic port allocation is broken with OpenSSH before 6.0. See
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
// rather than the actual port number. This means you can never open
// two different listeners with auto allocated ports. We work around
// this by trying explicit ports until we succeed.
const openSSHPrefix = "OpenSSH_"
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
// isBrokenOpenSSHVersion returns true if the given version string
// specifies a version of OpenSSH that is known to have a bug in port
// forwarding.
func isBrokenOpenSSHVersion(versionStr string) bool {
i := strings.Index(versionStr, openSSHPrefix)
if i < 0 {
return false
}
i += len(openSSHPrefix)
j := i
for ; j < len(versionStr); j++ {
if versionStr[j] < '0' || versionStr[j] > '9' {
break
}
}
version, _ := strconv.Atoi(versionStr[i:j])
return version < 6
}
// autoPortListenWorkaround simulates automatic port allocation by
// trying random ports repeatedly.
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
var sshListener net.Listener
var err error
const tries = 10
for i := 0; i < tries; i++ {
addr := *laddr
addr.Port = 1024 + portRandomizer.Intn(60000)
sshListener, err = c.ListenTCP(&addr)
if err == nil {
laddr.Port = addr.Port
return sshListener, err
}
}
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
}
// RFC 4254 7.1
type channelForwardMsg struct {
addr string
rport uint32
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}
m := channelForwardMsg{
laddr.IP.String(),
uint32(laddr.Port),
}
// send message
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("ssh: tcpip-forward request denied by peer")
}
// If the original port was 0, then the remote side will
// supply a real port number in the response.
if laddr.Port == 0 {
var p struct {
Port uint32
}
if err := Unmarshal(resp, &p); err != nil {
return nil, err
}
laddr.Port = int(p.Port)
}
// Register this forward, using the port number we obtained.
ch := c.forwards.add(*laddr)
return &tcpListener{laddr, c, ch}, nil
}
// forwardList stores a mapping between remote
// forward requests and the tcpListeners.
type forwardList struct {
sync.Mutex
entries []forwardEntry
}
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
laddr net.TCPAddr
c chan forward
}
// forward represents an incoming forwarded tcpip connection. The
// arguments to add/remove/lookup should be address as specified in
// the original forward-request.
type forward struct {
newCh NewChannel // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection
}
func (l *forwardList) add(addr net.TCPAddr) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
addr,
make(chan forward, 1),
}
l.entries = append(l.entries, f)
return f.c
}
// See RFC 4254, section 7.2
type forwardedTCPPayload struct {
Addr string
Port uint32
OriginAddr string
OriginPort uint32
}
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
if port == 0 || port > 65535 {
return nil, fmt.Errorf("ssh: port number out of range: %d", port)
}
ip := net.ParseIP(string(addr))
if ip == nil {
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
}
return &net.TCPAddr{IP: ip, Port: int(port)}, nil
}
func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in {
var payload forwardedTCPPayload
if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err := parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
if ok := l.forward(*laddr, *raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming
// connections.
ch.Reject(Prohibited, "no forward for address")
continue
}
}
}
// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.TCPAddr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c)
return
}
}
}
// closeAll closes and clears all forwards.
func (l *forwardList) closeAll() {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
close(f.c)
}
l.entries = nil
}
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
f.c <- forward{ch, &raddr}
return true
}
}
return false
}
type tcpListener struct {
laddr *net.TCPAddr
conn *Client
in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
func (l *tcpListener) Accept() (net.Conn, error) {
s, ok := <-l.in
if !ok {
return nil, io.EOF
}
ch, incoming, err := s.newCh.Accept()
if err != nil {
return nil, err
}
go DiscardRequests(incoming)
return &tcpChanConn{
Channel: ch,
laddr: l.laddr,
raddr: s.raddr,
}, nil
}
// Close closes the listener.
func (l *tcpListener) Close() error {
m := channelForwardMsg{
l.laddr.IP.String(),
uint32(l.laddr.Port),
}
// this also closes the listener.
l.conn.forwards.remove(*l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed")
}
return err
}
// Addr returns the listener's network address.
func (l *tcpListener) Addr() net.Addr {
return l.laddr
}
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
}
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
if laddr == nil {
laddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: laddr,
raddr: raddr,
}, nil
}
// RFC 4254 7.2
type channelOpenDirectMsg struct {
raddr string
rport uint32
laddr string
lport uint32
}
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
msg := channelOpenDirectMsg{
raddr: raddr,
rport: uint32(rport),
laddr: laddr,
lport: uint32(lport),
}
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
go DiscardRequests(in)
return ch, err
}
type tcpChan struct {
Channel // the backing channel
}
// tcpChanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
type tcpChanConn struct {
Channel
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpChanConn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpChanConn) RemoteAddr() net.Addr {
return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection.
func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil {
return err
}
return t.SetWriteDeadline(deadline)
}
// SetReadDeadline sets the read deadline.
// A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error
// with Timeout() == true.
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error.
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported")
}

View file

@ -0,0 +1,20 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
)
func TestAutoPortListenBroken(t *testing.T) {
broken := "SSH-2.0-OpenSSH_5.9hh11"
works := "SSH-2.0-OpenSSH_6.1"
if !isBrokenOpenSSHVersion(broken) {
t.Errorf("version %q not marked as broken", broken)
}
if isBrokenOpenSSHVersion(works) {
t.Errorf("version %q marked as broken", works)
}
}

View file

@ -0,0 +1,888 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import (
"bytes"
"io"
"sync"
"unicode/utf8"
)
// EscapeCodes contains escape sequences that can be written to the terminal in
// order to achieve different styles of text.
type EscapeCodes struct {
// Foreground colors
Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte
// Reset all attributes
Reset []byte
}
var vt100EscapeCodes = EscapeCodes{
Black: []byte{keyEscape, '[', '3', '0', 'm'},
Red: []byte{keyEscape, '[', '3', '1', 'm'},
Green: []byte{keyEscape, '[', '3', '2', 'm'},
Yellow: []byte{keyEscape, '[', '3', '3', 'm'},
Blue: []byte{keyEscape, '[', '3', '4', 'm'},
Magenta: []byte{keyEscape, '[', '3', '5', 'm'},
Cyan: []byte{keyEscape, '[', '3', '6', 'm'},
White: []byte{keyEscape, '[', '3', '7', 'm'},
Reset: []byte{keyEscape, '[', '0', 'm'},
}
// Terminal contains the state for running a VT100 terminal that is capable of
// reading lines of input.
type Terminal struct {
// AutoCompleteCallback, if non-null, is called for each keypress with
// the full input line and the current position of the cursor (in
// bytes, as an index into |line|). If it returns ok=false, the key
// press is processed normally. Otherwise it returns a replacement line
// and the new cursor position.
AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool)
// Escape contains a pointer to the escape codes for this terminal.
// It's always a valid pointer, although the escape codes themselves
// may be empty if the terminal doesn't support them.
Escape *EscapeCodes
// lock protects the terminal and the state in this object from
// concurrent processing of a key press and a Write() call.
lock sync.Mutex
c io.ReadWriter
prompt []rune
// line is the current line being entered.
line []rune
// pos is the logical position of the cursor in line
pos int
// echo is true if local echo is enabled
echo bool
// pasteActive is true iff there is a bracketed paste operation in
// progress.
pasteActive bool
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
// history contains previously entered commands so that they can be
// accessed with the up and down keys.
history stRingBuffer
// historyIndex stores the currently accessed history entry, where zero
// means the immediately previous entry.
historyIndex int
// When navigating up and down the history it's possible to return to
// the incomplete, initial line. That value is stored in
// historyPending.
historyPending string
}
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
return &Terminal{
Escape: &vt100EscapeCodes,
c: c,
prompt: []rune(prompt),
termWidth: 80,
termHeight: 24,
echo: true,
historyIndex: -1,
}
}
const (
keyCtrlD = 4
keyCtrlU = 21
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
keyHome
keyEnd
keyDeleteWord
keyDeleteLine
keyClearScreen
keyPasteStart
keyPasteEnd
)
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
func bytesToKey(b []byte, pasteActive bool) (rune, []byte) {
if len(b) == 0 {
return utf8.RuneError, nil
}
if !pasteActive {
switch b[0] {
case 1: // ^A
return keyHome, b[1:]
case 5: // ^E
return keyEnd, b[1:]
case 8: // ^H
return keyBackspace, b[1:]
case 11: // ^K
return keyDeleteLine, b[1:]
case 12: // ^L
return keyClearScreen, b[1:]
case 23: // ^W
return keyDeleteWord, b[1:]
}
}
if b[0] != keyEscape {
if !utf8.FullRune(b) {
return utf8.RuneError, b
}
r, l := utf8.DecodeRune(b)
return r, b[l:]
}
if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
case 'H':
return keyHome, b[3:]
case 'F':
return keyEnd, b[3:]
}
}
if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) {
return keyPasteStart, b[6:]
}
if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) {
return keyPasteEnd, b[6:]
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z~] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' {
return keyUnknown, b[i+1:]
}
}
return utf8.RuneError, b
}
// queue appends data to the end of t.outBuf
func (t *Terminal) queue(data []rune) {
t.outBuf = append(t.outBuf, []byte(string(data))...)
}
var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'}
var space = []rune{' '}
func isPrintable(key rune) bool {
isInSurrogateArea := key >= 0xd800 && key <= 0xdbff
return key >= 32 && !isInSurrogateArea
}
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
// given, logical position in the text.
func (t *Terminal) moveCursorToPos(pos int) {
if !t.echo {
return
}
x := visualLength(t.prompt) + pos
y := x / t.termWidth
x = x % t.termWidth
up := 0
if y < t.cursorY {
up = t.cursorY - y
}
down := 0
if y > t.cursorY {
down = y - t.cursorY
}
left := 0
if x < t.cursorX {
left = t.cursorX - x
}
right := 0
if x > t.cursorX {
right = x - t.cursorX
}
t.cursorX = x
t.cursorY = y
t.move(up, down, left, right)
}
func (t *Terminal) move(up, down, left, right int) {
movement := make([]rune, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
t.queue(movement)
}
func (t *Terminal) clearLineToRight() {
op := []rune{keyEscape, '[', 'K'}
t.queue(op)
}
const maxLineLength = 4096
func (t *Terminal) setLine(newLine []rune, newPos int) {
if t.echo {
t.moveCursorToPos(0)
t.writeLine(newLine)
for i := len(newLine); i < len(t.line); i++ {
t.writeLine(space)
}
t.moveCursorToPos(newPos)
}
t.line = newLine
t.pos = newPos
}
func (t *Terminal) advanceCursor(places int) {
t.cursorX += places
t.cursorY += t.cursorX / t.termWidth
if t.cursorY > t.maxLine {
t.maxLine = t.cursorY
}
t.cursorX = t.cursorX % t.termWidth
if places > 0 && t.cursorX == 0 {
// Normally terminals will advance the current position
// when writing a character. But that doesn't happen
// for the last character in a line. However, when
// writing a character (except a new line) that causes
// a line wrap, the position will be advanced two
// places.
//
// So, if we are stopping at the end of a line, we
// need to write a newline so that our cursor can be
// advanced to the next line.
t.outBuf = append(t.outBuf, '\n')
}
}
func (t *Terminal) eraseNPreviousChars(n int) {
if n == 0 {
return
}
if t.pos < n {
n = t.pos
}
t.pos -= n
t.moveCursorToPos(t.pos)
copy(t.line[t.pos:], t.line[n+t.pos:])
t.line = t.line[:len(t.line)-n]
if t.echo {
t.writeLine(t.line[t.pos:])
for i := 0; i < n; i++ {
t.queue(space)
}
t.advanceCursor(n)
t.moveCursorToPos(t.pos)
}
}
// countToLeftWord returns then number of characters from the cursor to the
// start of the previous word.
func (t *Terminal) countToLeftWord() int {
if t.pos == 0 {
return 0
}
pos := t.pos - 1
for pos > 0 {
if t.line[pos] != ' ' {
break
}
pos--
}
for pos > 0 {
if t.line[pos] == ' ' {
pos++
break
}
pos--
}
return t.pos - pos
}
// countToRightWord returns then number of characters from the cursor to the
// start of the next word.
func (t *Terminal) countToRightWord() int {
pos := t.pos
for pos < len(t.line) {
if t.line[pos] == ' ' {
break
}
pos++
}
for pos < len(t.line) {
if t.line[pos] != ' ' {
break
}
pos++
}
return pos - t.pos
}
// visualLength returns the number of visible glyphs in s.
func visualLength(runes []rune) int {
inEscapeSeq := false
length := 0
for _, r := range runes {
switch {
case inEscapeSeq:
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') {
inEscapeSeq = false
}
case r == '\x1b':
inEscapeSeq = true
default:
length++
}
}
return length
}
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (t *Terminal) handleKey(key rune) (line string, ok bool) {
if t.pasteActive && key != keyEnter {
t.addKeyToLine(key)
return
}
switch key {
case keyBackspace:
if t.pos == 0 {
return
}
t.eraseNPreviousChars(1)
case keyAltLeft:
// move left by a word.
t.pos -= t.countToLeftWord()
t.moveCursorToPos(t.pos)
case keyAltRight:
// move right by a word.
t.pos += t.countToRightWord()
t.moveCursorToPos(t.pos)
case keyLeft:
if t.pos == 0 {
return
}
t.pos--
t.moveCursorToPos(t.pos)
case keyRight:
if t.pos == len(t.line) {
return
}
t.pos++
t.moveCursorToPos(t.pos)
case keyHome:
if t.pos == 0 {
return
}
t.pos = 0
t.moveCursorToPos(t.pos)
case keyEnd:
if t.pos == len(t.line) {
return
}
t.pos = len(t.line)
t.moveCursorToPos(t.pos)
case keyUp:
entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1)
if !ok {
return "", false
}
if t.historyIndex == -1 {
t.historyPending = string(t.line)
}
t.historyIndex++
runes := []rune(entry)
t.setLine(runes, len(runes))
case keyDown:
switch t.historyIndex {
case -1:
return
case 0:
runes := []rune(t.historyPending)
t.setLine(runes, len(runes))
t.historyIndex--
default:
entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1)
if ok {
t.historyIndex--
runes := []rune(entry)
t.setLine(runes, len(runes))
}
}
case keyEnter:
t.moveCursorToPos(len(t.line))
t.queue([]rune("\r\n"))
line = string(t.line)
ok = true
t.line = t.line[:0]
t.pos = 0
t.cursorX = 0
t.cursorY = 0
t.maxLine = 0
case keyDeleteWord:
// Delete zero or more spaces and then one or more characters.
t.eraseNPreviousChars(t.countToLeftWord())
case keyDeleteLine:
// Delete everything from the current cursor position to the
// end of line.
for i := t.pos; i < len(t.line); i++ {
t.queue(space)
t.advanceCursor(1)
}
t.line = t.line[:t.pos]
t.moveCursorToPos(t.pos)
case keyCtrlD:
// Erase the character under the current position.
// The EOF case when the line is empty is handled in
// readLine().
if t.pos < len(t.line) {
t.pos++
t.eraseNPreviousChars(1)
}
case keyCtrlU:
t.eraseNPreviousChars(t.pos)
case keyClearScreen:
// Erases the screen and moves the cursor to the home position.
t.queue([]rune("\x1b[2J\x1b[H"))
t.queue(t.prompt)
t.cursorX, t.cursorY = 0, 0
t.advanceCursor(visualLength(t.prompt))
t.setLine(t.line, t.pos)
default:
if t.AutoCompleteCallback != nil {
prefix := string(t.line[:t.pos])
suffix := string(t.line[t.pos:])
t.lock.Unlock()
newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key)
t.lock.Lock()
if completeOk {
t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos]))
return
}
}
if !isPrintable(key) {
return
}
if len(t.line) == maxLineLength {
return
}
t.addKeyToLine(key)
}
return
}
// addKeyToLine inserts the given key at the current position in the current
// line.
func (t *Terminal) addKeyToLine(key rune) {
if len(t.line) == cap(t.line) {
newLine := make([]rune, len(t.line), 2*(1+len(t.line)))
copy(newLine, t.line)
t.line = newLine
}
t.line = t.line[:len(t.line)+1]
copy(t.line[t.pos+1:], t.line[t.pos:])
t.line[t.pos] = key
if t.echo {
t.writeLine(t.line[t.pos:])
}
t.pos++
t.moveCursorToPos(t.pos)
}
func (t *Terminal) writeLine(line []rune) {
for len(line) != 0 {
remainingOnLine := t.termWidth - t.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
t.queue(line[:todo])
t.advanceCursor(visualLength(line[:todo]))
line = line[todo:]
}
}
func (t *Terminal) Write(buf []byte) (n int, err error) {
t.lock.Lock()
defer t.lock.Unlock()
if t.cursorX == 0 && t.cursorY == 0 {
// This is the easy case: there's nothing on the screen that we
// have to move out of the way.
return t.c.Write(buf)
}
// We have a prompt and possibly user input on the screen. We
// have to clear it first.
t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */)
t.cursorX = 0
t.clearLineToRight()
for t.cursorY > 0 {
t.move(1 /* up */, 0, 0, 0)
t.cursorY--
t.clearLineToRight()
}
if _, err = t.c.Write(t.outBuf); err != nil {
return
}
t.outBuf = t.outBuf[:0]
if n, err = t.c.Write(buf); err != nil {
return
}
t.writeLine(t.prompt)
if t.echo {
t.writeLine(t.line)
}
t.moveCursorToPos(t.pos)
if _, err = t.c.Write(t.outBuf); err != nil {
return
}
t.outBuf = t.outBuf[:0]
return
}
// ReadPassword temporarily changes the prompt and reads a password, without
// echo, from the terminal.
func (t *Terminal) ReadPassword(prompt string) (line string, err error) {
t.lock.Lock()
defer t.lock.Unlock()
oldPrompt := t.prompt
t.prompt = []rune(prompt)
t.echo = false
line, err = t.readLine()
t.prompt = oldPrompt
t.echo = true
return
}
// ReadLine returns a line of input from the terminal.
func (t *Terminal) ReadLine() (line string, err error) {
t.lock.Lock()
defer t.lock.Unlock()
return t.readLine()
}
func (t *Terminal) readLine() (line string, err error) {
// t.lock must be held at this point
if t.cursorX == 0 && t.cursorY == 0 {
t.writeLine(t.prompt)
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
}
lineIsPasted := t.pasteActive
for {
rest := t.remainder
lineOk := false
for !lineOk {
var key rune
key, rest = bytesToKey(rest, t.pasteActive)
if key == utf8.RuneError {
break
}
if !t.pasteActive {
if key == keyCtrlD {
if len(t.line) == 0 {
return "", io.EOF
}
}
if key == keyPasteStart {
t.pasteActive = true
if len(t.line) == 0 {
lineIsPasted = true
}
continue
}
} else if key == keyPasteEnd {
t.pasteActive = false
continue
}
if !t.pasteActive {
lineIsPasted = false
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
n := copy(t.inBuf[:], rest)
t.remainder = t.inBuf[:n]
} else {
t.remainder = nil
}
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
if t.echo {
t.historyIndex = -1
t.history.Add(line)
}
if lineIsPasted {
err = ErrPasteIndicator
}
return
}
// t.remainder is a slice at the beginning of t.inBuf
// containing a partial key sequence
readBuf := t.inBuf[len(t.remainder):]
var n int
t.lock.Unlock()
n, err = t.c.Read(readBuf)
t.lock.Lock()
if err != nil {
return
}
t.remainder = t.inBuf[:n+len(t.remainder)]
}
panic("unreachable") // for Go 1.0.
}
// SetPrompt sets the prompt to be used when reading subsequent lines.
func (t *Terminal) SetPrompt(prompt string) {
t.lock.Lock()
defer t.lock.Unlock()
t.prompt = []rune(prompt)
}
func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) {
// Move cursor to column zero at the start of the line.
t.move(t.cursorY, 0, t.cursorX, 0)
t.cursorX, t.cursorY = 0, 0
t.clearLineToRight()
for t.cursorY < numPrevLines {
// Move down a line
t.move(0, 1, 0, 0)
t.cursorY++
t.clearLineToRight()
}
// Move back to beginning.
t.move(t.cursorY, 0, 0, 0)
t.cursorX, t.cursorY = 0, 0
t.queue(t.prompt)
t.advanceCursor(visualLength(t.prompt))
t.writeLine(t.line)
t.moveCursorToPos(t.pos)
}
func (t *Terminal) SetSize(width, height int) error {
t.lock.Lock()
defer t.lock.Unlock()
if width == 0 {
width = 1
}
oldWidth := t.termWidth
t.termWidth, t.termHeight = width, height
switch {
case width == oldWidth:
// If the width didn't change then nothing else needs to be
// done.
return nil
case width < oldWidth:
// Some terminals (e.g. xterm) will truncate lines that were
// too long when shinking. Others, (e.g. gnome-terminal) will
// attempt to wrap them. For the former, repainting t.maxLine
// works great, but that behaviour goes badly wrong in the case
// of the latter because they have doubled every full line.
// We assume that we are working on a terminal that wraps lines
// and adjust the cursor position based on every previous line
// wrapping and turning into two. This causes the prompt on
// xterms to move upwards, which isn't great, but it avoids a
// huge mess with gnome-terminal.
if t.cursorX >= t.termWidth {
t.cursorX = t.termWidth - 1
}
t.cursorY *= 2
t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2)
case width > oldWidth:
// If the terminal expands then our position calculations will
// be wrong in the future because we think the cursor is
// |t.pos| chars into the string, but there will be a gap at
// the end of any wrapped line.
//
// But the position will actually be correct until we move, so
// we can move back to the beginning and repaint everything.
t.clearAndRepaintLinePlusNPrevious(t.maxLine)
}
_, err := t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
return err
}
type pasteIndicatorError struct{}
func (pasteIndicatorError) Error() string {
return "terminal: ErrPasteIndicator not correctly handled"
}
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
// to valid line data. It indicates that bracketed paste mode is enabled and
// that the returned line consists only of pasted data. Programs may wish to
// interpret pasted data more literally than typed data.
var ErrPasteIndicator = pasteIndicatorError{}
// SetBracketedPasteMode requests that the terminal bracket paste operations
// with markers. Not all terminals support this but, if it is supported, then
// enabling this mode will stop any autocomplete callback from running due to
// pastes. Additionally, any lines that are completely pasted will be returned
// from ReadLine with the error set to ErrPasteIndicator.
func (t *Terminal) SetBracketedPasteMode(on bool) {
if on {
io.WriteString(t.c, "\x1b[?2004h")
} else {
io.WriteString(t.c, "\x1b[?2004l")
}
}
// stRingBuffer is a ring buffer of strings.
type stRingBuffer struct {
// entries contains max elements.
entries []string
max int
// head contains the index of the element most recently added to the ring.
head int
// size contains the number of elements in the ring.
size int
}
func (s *stRingBuffer) Add(a string) {
if s.entries == nil {
const defaultNumEntries = 100
s.entries = make([]string, defaultNumEntries)
s.max = defaultNumEntries
}
s.head = (s.head + 1) % s.max
s.entries[s.head] = a
if s.size < s.max {
s.size++
}
}
// NthPreviousEntry returns the value passed to the nth previous call to Add.
// If n is zero then the immediately prior value is returned, if one, then the
// next most recent, and so on. If such an element doesn't exist then ok is
// false.
func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
if n >= s.size {
return "", false
}
index := s.head - n
if index < 0 {
index += s.max
}
return s.entries[index], true
}

View file

@ -0,0 +1,243 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import (
"io"
"testing"
)
type MockTerminal struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockTerminal) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockTerminal) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func TestClose(t *testing.T) {
c := &MockTerminal{}
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != io.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err error
throwAwayLines int
}{
{
err: io.EOF,
},
{
in: "\r",
line: "",
},
{
in: "foo\r",
line: "foo",
},
{
in: "a\x1b[Cb\r", // right
line: "ab",
},
{
in: "a\x1b[Db\r", // left
line: "ba",
},
{
in: "a\177b\r", // backspace
line: "b",
},
{
in: "\x1b[A\r", // up
},
{
in: "\x1b[B\r", // down
},
{
in: "line\x1b[A\x1b[B\r", // up then down
line: "line",
},
{
in: "line1\rline2\x1b[A\r", // recall previous line.
line: "line1",
throwAwayLines: 1,
},
{
// recall two previous lines and append.
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
line: "line1xxx",
throwAwayLines: 2,
},
{
// Ctrl-A to move to beginning of line followed by ^K to kill
// line.
in: "a b \001\013\r",
line: "",
},
{
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
// finally ^K to kill nothing.
in: "a b \001\005\013\r",
line: "a b ",
},
{
in: "\027\r",
line: "",
},
{
in: "a\027\r",
line: "",
},
{
in: "a \027\r",
line: "",
},
{
in: "a b\027\r",
line: "a ",
},
{
in: "a b \027\r",
line: "a ",
},
{
in: "one two thr\x1b[D\027\r",
line: "one two r",
},
{
in: "\013\r",
line: "",
},
{
in: "a\013\r",
line: "a",
},
{
in: "ab\x1b[D\013\r",
line: "a",
},
{
in: "Ξεσκεπάζω\r",
line: "Ξεσκεπάζω",
},
{
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
line: "",
throwAwayLines: 1,
},
{
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
line: "£",
throwAwayLines: 1,
},
{
// Ctrl-D at the end of the line should be ignored.
in: "a\004\r",
line: "a",
},
{
// a, b, left, Ctrl-D should erase the b.
in: "ab\x1b[D\004\r",
line: "a",
},
{
// a, b, c, d, left, left, ^U should erase to the beginning of
// the line.
in: "abcd\x1b[D\x1b[D\025\r",
line: "cd",
},
{
// Bracketed paste mode: control sequences should be returned
// verbatim in paste mode.
in: "abc\x1b[200~de\177f\x1b[201~\177\r",
line: "abcde\177",
},
{
// Enter in bracketed paste mode should still work.
in: "abc\x1b[200~d\refg\x1b[201~h\r",
line: "efgh",
throwAwayLines: 1,
},
{
// Lines consisting entirely of pasted data should be indicated as such.
in: "\x1b[200~a\r",
line: "a",
err: ErrPasteIndicator,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 1; j < len(test.in); j++ {
c := &MockTerminal{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
for k := 0; k < test.throwAwayLines; k++ {
_, err := ss.ReadLine()
if err != nil {
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
}
}
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}
func TestPasswordNotSaved(t *testing.T) {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
pw, _ := ss.ReadPassword("> ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
line, _ := ss.ReadLine()
if len(line) > 0 {
t.Fatalf("password was saved in history")
}
}

View file

@ -0,0 +1,128 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"syscall"
"unsafe"
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
return err == 0
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
return nil, err
}
return &oldState, nil
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0)
return err
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
var dimensions [4]uint16
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
return -1, -1, err
}
return int(dimensions[1]), int(dimensions[0]), nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 {
return nil, err
}
newState := oldState
newState.Lflag &^= syscall.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG
newState.Iflag |= syscall.ICRNL
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
return nil, err
}
defer func() {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
}()
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}

View file

@ -0,0 +1,12 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd netbsd openbsd
package terminal
import "syscall"
const ioctlReadTermios = syscall.TIOCGETA
const ioctlWriteTermios = syscall.TIOCSETA

View file

@ -0,0 +1,11 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
// These constants are declared here, rather than importing
// them from the syscall package as some syscall packages, even
// on linux, for example gccgo, do not declare them.
const ioctlReadTermios = 0x5401 // syscall.TCGETS
const ioctlWriteTermios = 0x5402 // syscall.TCSETS

View file

@ -0,0 +1,174 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"syscall"
"unsafe"
)
const (
enableLineInput = 2
enableEchoInput = 4
enableProcessedInput = 1
enableWindowInput = 8
enableMouseInput = 16
enableInsertMode = 32
enableQuickEditMode = 64
enableExtendedFlags = 128
enableAutoPosition = 256
enableProcessedOutput = 1
enableWrapAtEolOutput = 2
)
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
var (
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
type (
short int16
word uint16
coord struct {
x short
y short
}
smallRect struct {
left short
top short
right short
bottom short
}
consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes word
window smallRect
maximumWindowSize coord
}
)
type State struct {
mode uint32
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var st uint32
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
return r != 0 && e == 0
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
}
st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
if e != 0 {
return nil, error(e)
}
return &State{st}, nil
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
}
return &State{st}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0)
return err
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
var info consoleScreenBufferInfo
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0)
if e != 0 {
return 0, 0, error(e)
}
return int(info.size.x), int(info.size.y), nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
}
old := st
st &^= (enableEchoInput)
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
if e != 0 {
return nil, error(e)
}
defer func() {
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
}()
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(syscall.Handle(fd), buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
if n > 0 && buf[n-1] == '\r' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}

View file

@ -0,0 +1,50 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"bytes"
"testing"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
func TestAgentForward(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
keyring := agent.NewKeyring()
keyring.Add(testPrivateKeys["dsa"], nil, "")
pub := testPublicKeys["dsa"]
sess, err := conn.NewSession()
if err != nil {
t.Fatalf("NewSession: %v", err)
}
if err := agent.RequestAgentForwarding(sess); err != nil {
t.Fatalf("RequestAgentForwarding: %v", err)
}
if err := agent.ForwardToAgent(conn, keyring); err != nil {
t.Fatalf("SetupForwardKeyring: %v", err)
}
out, err := sess.CombinedOutput("ssh-add -L")
if err != nil {
t.Fatalf("running ssh-add: %v, out %s", err, out)
}
key, _, _, _, err := ssh.ParseAuthorizedKey(out)
if err != nil {
t.Fatalf("ParseAuthorizedKey(%q): %v", out, err)
}
if !bytes.Equal(key.Marshal(), pub.Marshal()) {
t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub))
}
}

View file

@ -0,0 +1,47 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"crypto/rand"
"testing"
"golang.org/x/crypto/ssh"
)
func TestCertLogin(t *testing.T) {
s := newServer(t)
defer s.Shutdown()
// Use a key different from the default.
clientKey := testSigners["dsa"]
caAuthKey := testSigners["ecdsa"]
cert := &ssh.Certificate{
Key: clientKey.PublicKey(),
ValidPrincipals: []string{username()},
CertType: ssh.UserCert,
ValidBefore: ssh.CertTimeInfinity,
}
if err := cert.SignCert(rand.Reader, caAuthKey); err != nil {
t.Fatalf("SetSignature: %v", err)
}
certSigner, err := ssh.NewCertSigner(cert, clientKey)
if err != nil {
t.Fatalf("NewCertSigner: %v", err)
}
conf := &ssh.ClientConfig{
User: username(),
}
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
client, err := s.TryDial(conf)
if err != nil {
t.Fatalf("TryDial: %v", err)
}
client.Close()
}

View file

@ -0,0 +1,7 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This package contains integration tests for the
// code.google.com/p/go.crypto/ssh package.
package test

View file

@ -0,0 +1,160 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"net"
"testing"
"time"
)
func TestPortForward(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
sshListener, err := conn.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
go func() {
sshConn, err := sshListener.Accept()
if err != nil {
t.Fatalf("listen.Accept failed: %v", err)
}
_, err = io.Copy(sshConn, sshConn)
if err != nil && err != io.EOF {
t.Fatalf("ssh client copy: %v", err)
}
sshConn.Close()
}()
forwardedAddr := sshListener.Addr().String()
tcpConn, err := net.Dial("tcp", forwardedAddr)
if err != nil {
t.Fatalf("TCP dial failed: %v", err)
}
readChan := make(chan []byte)
go func() {
data, _ := ioutil.ReadAll(tcpConn)
readChan <- data
}()
// Invent some data.
data := make([]byte, 100*1000)
for i := range data {
data[i] = byte(i % 255)
}
var sent []byte
for len(sent) < 1000*1000 {
// Send random sized chunks
m := rand.Intn(len(data))
n, err := tcpConn.Write(data[:m])
if err != nil {
break
}
sent = append(sent, data[:n]...)
}
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
t.Errorf("tcpConn.CloseWrite: %v", err)
}
read := <-readChan
if len(sent) != len(read) {
t.Fatalf("got %d bytes, want %d", len(read), len(sent))
}
if bytes.Compare(sent, read) != 0 {
t.Fatalf("read back data does not match")
}
if err := sshListener.Close(); err != nil {
t.Fatalf("sshListener.Close: %v", err)
}
// Check that the forward disappeared.
tcpConn, err = net.Dial("tcp", forwardedAddr)
if err == nil {
tcpConn.Close()
t.Errorf("still listening to %s after closing", forwardedAddr)
}
}
func TestAcceptClose(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
quit := make(chan error, 1)
go func() {
for {
c, err := sshListener.Accept()
if err != nil {
quit <- err
break
}
c.Close()
}
}()
sshListener.Close()
select {
case <-time.After(1 * time.Second):
t.Errorf("timeout: listener did not close.")
case err := <-quit:
t.Logf("quit as expected (error %v)", err)
}
}
// Check that listeners exit if the underlying client transport dies.
func TestPortForwardConnectionClose(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
quit := make(chan error, 1)
go func() {
for {
c, err := sshListener.Accept()
if err != nil {
quit <- err
break
}
c.Close()
}
}()
// It would be even nicer if we closed the server side, but it
// is more involved as the fd for that side is dup()ed.
server.clientConn.Close()
select {
case <-time.After(1 * time.Second):
t.Errorf("timeout: listener did not close.")
case err := <-quit:
t.Logf("quit as expected (error %v)", err)
}
}

View file

@ -0,0 +1,317 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// Session functional tests.
import (
"bytes"
"errors"
"golang.org/x/crypto/ssh"
"io"
"strings"
"testing"
)
func TestRunCommandSuccess(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
err = session.Run("true")
if err != nil {
t.Fatalf("session failed: %v", err)
}
}
func TestHostKeyCheck(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
hostDB := hostKeyDB()
conf.HostKeyCallback = hostDB.Check
// change the keys.
hostDB.keys[ssh.KeyAlgoRSA][25]++
hostDB.keys[ssh.KeyAlgoDSA][25]++
hostDB.keys[ssh.KeyAlgoECDSA256][25]++
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
t.Fatalf("dial should have failed.")
} else if !strings.Contains(err.Error(), "host key mismatch") {
t.Fatalf("'host key mismatch' not found in %v", err)
}
}
func TestRunCommandStdin(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
r, w := io.Pipe()
defer r.Close()
defer w.Close()
session.Stdin = r
err = session.Run("true")
if err != nil {
t.Fatalf("session failed: %v", err)
}
}
func TestRunCommandStdinError(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
r, w := io.Pipe()
defer r.Close()
session.Stdin = r
pipeErr := errors.New("closing write end of pipe")
w.CloseWithError(pipeErr)
err = session.Run("true")
if err != pipeErr {
t.Fatalf("expected %v, found %v", pipeErr, err)
}
}
func TestRunCommandFailed(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
err = session.Run(`bash -c "kill -9 $$"`)
if err == nil {
t.Fatalf("session succeeded: %v", err)
}
}
func TestRunCommandWeClosed(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
err = session.Shell()
if err != nil {
t.Fatalf("shell failed: %v", err)
}
err = session.Close()
if err != nil {
t.Fatalf("shell failed: %v", err)
}
}
func TestFuncLargeRead(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("unable to create new session: %s", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("unable to acquire stdout pipe: %s", err)
}
err = session.Start("dd if=/dev/urandom bs=2048 count=1024")
if err != nil {
t.Fatalf("unable to execute remote command: %s", err)
}
buf := new(bytes.Buffer)
n, err := io.Copy(buf, stdout)
if err != nil {
t.Fatalf("error reading from remote stdout: %s", err)
}
if n != 2048*1024 {
t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n)
}
}
func TestKeyChange(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
hostDB := hostKeyDB()
conf.HostKeyCallback = hostDB.Check
conf.RekeyThreshold = 1024
conn := server.Dial(conf)
defer conn.Close()
for i := 0; i < 4; i++ {
session, err := conn.NewSession()
if err != nil {
t.Fatalf("unable to create new session: %s", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("unable to acquire stdout pipe: %s", err)
}
err = session.Start("dd if=/dev/urandom bs=1024 count=1")
if err != nil {
t.Fatalf("unable to execute remote command: %s", err)
}
buf := new(bytes.Buffer)
n, err := io.Copy(buf, stdout)
if err != nil {
t.Fatalf("error reading from remote stdout: %s", err)
}
want := int64(1024)
if n != want {
t.Fatalf("Expected %d bytes but read only %d from remote command", want, n)
}
}
if changes := hostDB.checkCount; changes < 4 {
t.Errorf("got %d key changes, want 4", changes)
}
}
func TestInvalidTerminalMode(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil {
t.Fatalf("req-pty failed: successful request with invalid mode")
}
}
func TestValidTerminalMode(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("unable to acquire stdout pipe: %s", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("unable to acquire stdin pipe: %s", err)
}
tm := ssh.TerminalModes{ssh.ECHO: 0}
if err = session.RequestPty("xterm", 80, 40, tm); err != nil {
t.Fatalf("req-pty failed: %s", err)
}
err = session.Shell()
if err != nil {
t.Fatalf("session failed: %s", err)
}
stdin.Write([]byte("stty -a && exit\n"))
var buf bytes.Buffer
if _, err := io.Copy(&buf, stdout); err != nil {
t.Fatalf("reading failed: %s", err)
}
if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") {
t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput)
}
}
func TestCiphers(t *testing.T) {
var config ssh.Config
config.SetDefaults()
cipherOrder := config.Ciphers
for _, ciph := range cipherOrder {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
conf.Ciphers = []string{ciph}
// Don't fail if sshd doesnt have the cipher.
conf.Ciphers = append(conf.Ciphers, cipherOrder...)
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
} else {
t.Fatalf("failed for cipher %q", ciph)
}
}
}
func TestMACs(t *testing.T) {
var config ssh.Config
config.SetDefaults()
macOrder := config.MACs
for _, mac := range macOrder {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
conf.MACs = []string{mac}
// Don't fail if sshd doesnt have the MAC.
conf.MACs = append(conf.MACs, macOrder...)
if conn, err := server.TryDial(conf); err == nil {
conn.Close()
} else {
t.Fatalf("failed for MAC %q", mac)
}
}
}

View file

@ -0,0 +1,46 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip functional tests
import (
"io"
"net"
"testing"
)
func TestDial(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
io.WriteString(c, c.RemoteAddr().String())
c.Close()
}
}()
conn, err := sshConn.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
}

View file

@ -0,0 +1,261 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd plan9
package test
// functional test harness for unix.
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"testing"
"text/template"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"
)
const sshd_config = `
Protocol 2
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/id_user.pub
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
`
var configTmpl = template.Must(template.New("").Parse(sshd_config))
type server struct {
t *testing.T
cleanup func() // executed during Shutdown
configfile string
cmd *exec.Cmd
output bytes.Buffer // holds stderr from sshd process
// Client half of the network connection.
clientConn net.Conn
}
func username() string {
var username string
if user, err := user.Current(); err == nil {
username = user.Username
} else {
// user.Current() currently requires cgo. If an error is
// returned attempt to get the username from the environment.
log.Printf("user.Current: %v; falling back on $USER", err)
username = os.Getenv("USER")
}
if username == "" {
panic("Unable to get username")
}
return username
}
type storedHostKey struct {
// keys map from an algorithm string to binary key data.
keys map[string][]byte
// checkCount counts the Check calls. Used for testing
// rekeying.
checkCount int
}
func (k *storedHostKey) Add(key ssh.PublicKey) {
if k.keys == nil {
k.keys = map[string][]byte{}
}
k.keys[key.Type()] = key.Marshal()
}
func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
k.checkCount++
algo := key.Type()
if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
}
return nil
}
func hostKeyDB() *storedHostKey {
keyChecker := &storedHostKey{}
keyChecker.Add(testPublicKeys["ecdsa"])
keyChecker.Add(testPublicKeys["rsa"])
keyChecker.Add(testPublicKeys["dsa"])
return keyChecker
}
func clientConfig() *ssh.ClientConfig {
config := &ssh.ClientConfig{
User: username(),
Auth: []ssh.AuthMethod{
ssh.PublicKeys(testSigners["user"]),
},
HostKeyCallback: hostKeyDB().Check,
}
return config
}
// unixConnection creates two halves of a connected net.UnixConn. It
// is used for connecting the Go SSH client with sshd without opening
// ports.
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
dir, err := ioutil.TempDir("", "unixConnection")
if err != nil {
return nil, nil, err
}
defer os.Remove(dir)
addr := filepath.Join(dir, "ssh")
listener, err := net.Listen("unix", addr)
if err != nil {
return nil, nil, err
}
defer listener.Close()
c1, err := net.Dial("unix", addr)
if err != nil {
return nil, nil, err
}
c2, err := listener.Accept()
if err != nil {
c1.Close()
return nil, nil, err
}
return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
}
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
}
c1, c2, err := unixConnection()
if err != nil {
s.t.Fatalf("unixConnection: %v", err)
}
s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
f, err := c2.File()
if err != nil {
s.t.Fatalf("UnixConn.File: %v", err)
}
defer f.Close()
s.cmd.Stdin = f
s.cmd.Stdout = f
s.cmd.Stderr = &s.output
if err := s.cmd.Start(); err != nil {
s.t.Fail()
s.Shutdown()
s.t.Fatalf("s.cmd.Start: %v", err)
}
s.clientConn = c1
conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
if err != nil {
return nil, err
}
return ssh.NewClient(conn, chans, reqs), nil
}
func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
conn, err := s.TryDial(config)
if err != nil {
s.t.Fail()
s.Shutdown()
s.t.Fatalf("ssh.Client: %v", err)
}
return conn
}
func (s *server) Shutdown() {
if s.cmd != nil && s.cmd.Process != nil {
// Don't check for errors; if it fails it's most
// likely "os: process already finished", and we don't
// care about that. Use os.Interrupt, so child
// processes are killed too.
s.cmd.Process.Signal(os.Interrupt)
s.cmd.Wait()
}
if s.t.Failed() {
// log any output from sshd process
s.t.Logf("sshd: %s", s.output.String())
}
s.cleanup()
}
func writeFile(path string, contents []byte) {
f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
if err != nil {
panic(err)
}
defer f.Close()
if _, err := f.Write(contents); err != nil {
panic(err)
}
}
// newServer returns a new mock ssh server.
func newServer(t *testing.T) *server {
if testing.Short() {
t.Skip("skipping test due to -short")
}
dir, err := ioutil.TempDir("", "sshtest")
if err != nil {
t.Fatal(err)
}
f, err := os.Create(filepath.Join(dir, "sshd_config"))
if err != nil {
t.Fatal(err)
}
err = configTmpl.Execute(f, map[string]string{
"Dir": dir,
})
if err != nil {
t.Fatal(err)
}
f.Close()
for k, v := range testdata.PEMBytes {
filename := "id_" + k
writeFile(filepath.Join(dir, filename), v)
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
}
return &server{
t: t,
configfile: f.Name(),
cleanup: func() {
if err := os.RemoveAll(dir); err != nil {
t.Error(err)
}
},
}
}

View file

@ -0,0 +1,64 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.
package test
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"
)
var (
testPrivateKeys map[string]interface{}
testSigners map[string]ssh.Signer
testPublicKeys map[string]ssh.PublicKey
)
func init() {
var err error
n := len(testdata.PEMBytes)
testPrivateKeys = make(map[string]interface{}, n)
testSigners = make(map[string]ssh.Signer, n)
testPublicKeys = make(map[string]ssh.PublicKey, n)
for t, k := range testdata.PEMBytes {
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
if err != nil {
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
}
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
if err != nil {
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
}
testPublicKeys[t] = testSigners[t].PublicKey()
}
// Create a cert and sign it for use in tests.
testCert := &ssh.Certificate{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: testPublicKeys["ecdsa"],
SignatureKey: testPublicKeys["rsa"],
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
testCert.SignCert(rand.Reader, testSigners["rsa"])
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
if err != nil {
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
}
}

View file

@ -0,0 +1,8 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This package contains test data shared between the various subpackages of
// the code.google.com/p/go.crypto/ssh package. Under no circumstance should
// this data be used for production code.
package testdata

View file

@ -0,0 +1,43 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testdata
var PEMBytes = map[string][]byte{
"dsa": []byte(`-----BEGIN DSA PRIVATE KEY-----
MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
Fmsr0W6fHB9nhS4/UXM8
-----END DSA PRIVATE KEY-----
`),
"ecdsa": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY-----
`),
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
-----END RSA PRIVATE KEY-----
`),
"user": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
-----END EC PRIVATE KEY-----
`),
}

View file

@ -0,0 +1,63 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.
package ssh
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/ssh/testdata"
)
var (
testPrivateKeys map[string]interface{}
testSigners map[string]Signer
testPublicKeys map[string]PublicKey
)
func init() {
var err error
n := len(testdata.PEMBytes)
testPrivateKeys = make(map[string]interface{}, n)
testSigners = make(map[string]Signer, n)
testPublicKeys = make(map[string]PublicKey, n)
for t, k := range testdata.PEMBytes {
testPrivateKeys[t], err = ParseRawPrivateKey(k)
if err != nil {
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
}
testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
if err != nil {
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
}
testPublicKeys[t] = testSigners[t].PublicKey()
}
// Create a cert and sign it for use in tests.
testCert := &Certificate{
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
ValidAfter: 0, // unix epoch
ValidBefore: CertTimeInfinity, // The end of currently representable time.
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
Key: testPublicKeys["ecdsa"],
SignatureKey: testPublicKeys["rsa"],
Permissions: Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
testCert.SignCert(rand.Reader, testSigners["rsa"])
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
if err != nil {
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
}
}

View file

@ -0,0 +1,327 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"errors"
"io"
)
const (
gcmCipherID = "aes128-gcm@openssh.com"
)
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
// Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error
// Read a packet from the connection
readPacket() ([]byte, error)
// Close closes the write-side of the connection.
Close() error
}
// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
reader connectionState
writer connectionState
bufReader *bufio.Reader
bufWriter *bufio.Writer
rand io.Reader
io.Closer
// Initial H used for the session ID. Once assigned this does
// not change, even during subsequent key exchanges.
sessionID []byte
}
func (t *transport) getSessionID() []byte {
if t.sessionID == nil {
panic("session ID not set yet")
}
s := make([]byte, len(t.sessionID))
copy(s, t.sessionID)
return s
}
// packetCipher represents a combination of SSH encryption/MAC
// protocol. A single instance should be used for one direction only.
type packetCipher interface {
// writePacket encrypts the packet and writes it to w. The
// contents of the packet are generally scrambled.
writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
// readPacket reads and decrypts a packet of data. The
// returned packet may be overwritten by future calls of
// readPacket.
readPacket(seqnum uint32, r io.Reader) ([]byte, error)
}
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms
type connectionState struct {
packetCipher
seqNum uint32
dir direction
pendingKeyChange chan packetCipher
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if t.sessionID == nil {
t.sessionID = kexResult.H
}
kexResult.SessionID = t.sessionID
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
return err
} else {
t.reader.pendingKeyChange <- ciph
}
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
return err
} else {
t.writer.pendingKeyChange <- ciph
}
return nil
}
// Read and decrypt next packet.
func (t *transport) readPacket() ([]byte, error) {
return t.reader.readPacket(t.bufReader)
}
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
packet, err := s.packetCipher.readPacket(s.seqNum, r)
s.seqNum++
if err == nil && len(packet) == 0 {
err = errors.New("ssh: zero length packet")
}
if len(packet) > 0 && packet[0] == msgNewKeys {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}
}
// The packet may point to an internal buffer, so copy the
// packet out here.
fresh := make([]byte, len(packet))
copy(fresh, packet)
return fresh, err
}
func (t *transport) writePacket(packet []byte) error {
return t.writer.writePacket(t.bufWriter, t.rand, packet)
}
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
err := s.packetCipher.writePacket(s.seqNum, w, rand, packet)
if err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
s.seqNum++
if changeKeys {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
panic("ssh: no key material for msgNewKeys")
}
}
return err
}
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
t := &transport{
bufReader: bufio.NewReader(rwc),
bufWriter: bufio.NewWriter(rwc),
rand: rand,
reader: connectionState{
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
pendingKeyChange: make(chan packetCipher, 1),
},
writer: connectionState{
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
pendingKeyChange: make(chan packetCipher, 1),
},
Closer: rwc,
}
if isClient {
t.reader.dir = serverKeys
t.writer.dir = clientKeys
} else {
t.reader.dir = clientKeys
t.writer.dir = serverKeys
}
return t
}
type direction struct {
ivTag []byte
keyTag []byte
macKeyTag []byte
}
var (
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// generateKeys generates key material for IV, MAC and encryption.
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
iv = make([]byte, cipherMode.ivSize)
key = make([]byte, cipherMode.keySize)
macKey = make([]byte, macMode.keySize)
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
return
}
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
iv, key, macKey := generateKeys(d, algs, kex)
if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key, macKey)
}
c := &streamPacketCipher{
mac: macModes[algs.MAC].new(macKey),
}
c.macResult = make([]byte, c.mac.Size())
var err error
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
if err != nil {
return nil, err
}
return c, nil
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
func generateKeyMaterial(out, tag []byte, r *kexResult) {
var digestsSoFar []byte
h := r.Hash.New()
for len(out) > 0 {
h.Reset()
h.Write(r.K)
h.Write(r.H)
if len(digestsSoFar) == 0 {
h.Write(tag)
h.Write(r.SessionID)
} else {
h.Write(digestsSoFar)
}
digest := h.Sum(nil)
n := copy(out, digest)
out = out[n:]
if len(out) > 0 {
digestsSoFar = append(digestsSoFar, digest...)
}
}
}
const packageVersion = "SSH-2.0-Go"
// Sends and receives a version line. The versionLine string should
// be US ASCII, start with "SSH-2.0-", and should not include a
// newline. exchangeVersions returns the other side's version line.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
// Contrary to the RFC, we do not ignore lines that don't
// start with "SSH-2.0-" to make the library usable with
// nonconforming servers.
for _, c := range versionLine {
// The spec disallows non US-ASCII chars, and
// specifically forbids null chars.
if c < 32 {
return nil, errors.New("ssh: junk character in version line")
}
}
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
return
}
them, err = readVersion(rw)
return them, err
}
// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars
const maxVersionStringBytes = 255
// Read version string as specified by RFC 4253, section 4.2.
func readVersion(r io.Reader) ([]byte, error) {
versionString := make([]byte, 0, 64)
var ok bool
var buf [1]byte
for len(versionString) < maxVersionStringBytes {
_, err := io.ReadFull(r, buf[:])
if err != nil {
return nil, err
}
// The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
ok = true
break
}
// non ASCII chars are disallowed, but we are lenient,
// since Go doesn't use null-terminated strings.
// The RFC allows a comment after a space, however,
// all of it (version and comments) goes into the
// session hash.
versionString = append(versionString, buf[0])
}
if !ok {
return nil, errors.New("ssh: overflow reading version string")
}
// There might be a '\r' on the end which we should remove.
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
versionString = versionString[:len(versionString)-1]
}
return versionString, nil
}

View file

@ -0,0 +1,109 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto/rand"
"encoding/binary"
"strings"
"testing"
)
func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion,
}
for in, want := range cases {
result, err := readVersion(bytes.NewBufferString(in))
if err != nil {
t.Errorf("readVersion(%q): %s", in, err)
}
got := string(result)
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
}
func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := []string{
longversion + "too-long\r\n",
}
for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
t.Errorf("readVersion(%q) should have failed", in)
}
}
}
func TestExchangeVersionsBasic(t *testing.T) {
v := "SSH-2.0-bla"
buf := bytes.NewBufferString(v + "\r\n")
them, err := exchangeVersions(buf, []byte("xyz"))
if err != nil {
t.Errorf("exchangeVersions: %v", err)
}
if want := "SSH-2.0-bla"; string(them) != want {
t.Errorf("got %q want %q for our version", them, want)
}
}
func TestExchangeVersions(t *testing.T) {
cases := []string{
"not\x000allowed",
"not allowed\n",
}
for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
if _, err := exchangeVersions(buf, []byte(c)); err == nil {
t.Errorf("exchangeVersions(%q): should have failed", c)
}
}
}
type closerBuffer struct {
bytes.Buffer
}
func (b *closerBuffer) Close() error {
return nil
}
func TestTransportMaxPacketWrite(t *testing.T) {
buf := &closerBuffer{}
tr := newTransport(buf, rand.Reader, true)
huge := make([]byte, maxPacket+1)
err := tr.writePacket(huge)
if err == nil {
t.Errorf("transport accepted write for a huge packet.")
}
}
func TestTransportMaxPacketReader(t *testing.T) {
var header [5]byte
huge := make([]byte, maxPacket+128)
binary.BigEndian.PutUint32(header[0:], uint32(len(huge)))
// padding.
header[4] = 0
buf := &closerBuffer{}
buf.Write(header[:])
buf.Write(huge)
tr := newTransport(buf, rand.Reader, true)
_, err := tr.readPacket()
if err == nil {
t.Errorf("transport succeeded reading huge packet.")
} else if !strings.Contains(err.Error(), "large") {
t.Errorf("got %q, should mention %q", err.Error(), "large")
}
}

View file

@ -0,0 +1,7 @@
build:
go build
push:
git push origin master
git push github master

View file

@ -0,0 +1,40 @@
Package migration for Golang automatically handles versioning of a database
schema by applying a series of migrations supplied by the client. It uses
features only from the database/sql package, so it tries to be driver
independent. However, to track the version of the database, it is necessary to
execute some SQL. I've made an effort to keep those queries simple, but if they
don't work with your database, you may override them.
This package works by applying a series of migrations to a database. Once a
migration is created, it should never be changed. Every time a database is
opened with this package, all necessary migrations are executed in a single
transaction. If any part of the process fails, an error is returned and the
transaction is rolled back so that the database is left untouched. (Note that
for this to be useful, you'll need to use a database that supports rolling back
changes to your schema. Notably, MySQL does not support this, although SQLite
and PostgreSQL do.)
The version of a database is defined as the number of migrations applied to it.
### Installation
If you have Go installed and
[your GOPATH is setup](http://golang.org/doc/code.html#GOPATH), then
`migration` can be installed with `go get`:
go get github.com/BurntSushi/migration
### Documentation
Documentation is available at
[godoc.org/github.com/BurntSushi/migration](http://godoc.org/github.com/BurntSushi/migration).
### Unstable
At the moment, I'm still experimenting with the public API, so I may still
introduce breaking changes. In general though, I am happy with the overall
architecture.

View file

@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>

View file

@ -0,0 +1,20 @@
/*
Package migration automatically handles versioning of a database
schema by applying a series of migrations supplied by the client. It uses
features only from the database/sql package, so it tries to be driver
independent. However, to track the version of the database, it is necessary to
execute some SQL. I've made an effort to keep those queries simple, but if they
don't work with your database, you may override them.
This package works by applying a series of migrations to a database. Once a
migration is created, it should never be changed. Every time a database is
opened with this package, all necessary migrations are executed in a single
transaction. If any part of the process fails, an error is returned and the
transaction is rolled back so that the database is left untouched. (Note that
for this to be useful, you'll need to use a database that supports rolling back
changes to your schema. Notably, MySQL does not support this, although SQLite
and PostgreSQL do.)
The version of a database is defined as the number of migrations applied to it.
*/
package migration

View file

@ -0,0 +1,237 @@
package migration
import (
"database/sql"
"fmt"
)
var ef = fmt.Errorf
// LimitedTx specifies the behavior of a transaction *without* commit and
// rollback functions. Values with this type are given to client functions.
// In particular, the migration routines in this package
// handle transaction commits and rollbacks. Therefore the functions provided
// by the client should not use them.
type LimitedTx interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Stmt(stmt *sql.Stmt) *sql.Stmt
}
// GetVersion is any function that can retrieve the migration version of a
// particular database. It is exposed in case a client wants to override the
// default behavior of this package. (For example, by using the `user_version`
// PRAGMA in SQLite.)
//
// The DefaultGetVersion function provided with this package creates its own
// table with a single column and a single row.
//
// The version returned should be equivalent to the number of migrations
// applied to this database. It should be 0 if no migrations have been applied
// yet.
//
// If an error is returned, the migration automatically fails.
//
// Note that a LimitedTx is used to emphasize that functions with this type
// MUST NOT call Commit or Rollback. The migration routine in this pacakge will
// do it for you.
type GetVersion func(LimitedTx) (int, error)
// The default way to get the version from a database. If the database has
// had no migrations performed, then it creates a table with a single row and
// a single column storing the version as 0. It then returns 0.
//
// If the table exists, then the version stored in the table is returned.
var DefaultGetVersion GetVersion = defaultGetVersion
// SetVersion is the dual of GetVersion. It allows the client to define a
// different mechanism for setting the database version than the one used by
// DefaultSetVersion in this package.
//
// If an error is returned, the migration that tried to set the version
// automatically fails.
//
// Note that a LimitedTx is used to emphasize that functions with this type
// MUST NOT call Commit or Rollback. The migration routine in this pacakge will
// do it for you.
type SetVersion func(LimitedTx, int) error
// The default way to set the version of the database. If the database has had
// no migrations performed, then it creates a table with a single row and a
// single column and storing the version given there.
//
// If the table exists, then the existing version is overwritten.
var DefaultSetVersion SetVersion = defaultSetVersion
// Migrator corresponds to a function that updates the database by one version.
// Note that a migration should NOT call Rollback or Commit. Instead, this
// package will call Rollback for you if your migration returns an error. If
// no error is returned, then the next migration is applied. When all
// migrations have been applied, the version is updated and the changes are
// committed to the database.
type Migrator func(LimitedTx) error
// Open wraps the Open function from the database/sql package, but performs
// a series of migrations on a database if they haven't been performed already.
//
// Migrations are tracked by a simple versioning scheme. The version of the
// database is the number of migrations that have been performed on it.
// Similarly, the version of your library is the number of migrations that are
// given to this function.
//
// If Open returns successfully, then the database and your library will have
// the same versions. If there was a problem migrating---or if the database
// version is greater than your library version---then an error is returned.
// Since all migrations are performed in a single transaction, if an error
// occurs, no changes are made to the database. (Assuming you're using a
// relational database that allows modifications to a schema to be rolled back.)
//
// Note that this versioning scheme includes no semantic analysis. It is up to
// client to ensure that once a migration is defined, it never changes.
//
// The details of how the version is stored are opaque to the client, but in
// general, it will add a table to your database called "migration_version"
// with a single column containing a single row.
func Open(driver, dsn string, migrations []Migrator) (*sql.DB, error) {
return OpenWith(driver, dsn, migrations, nil, nil)
}
// OpenWith is exactly like Open, except it allows the client to specify their
// own versioning scheme. Note that vget and vset must BOTH be
// nil or BOTH be non-nil. Otherwise, this function panics. This is because the
// implementation of one generally relies on the implementation of the other.
//
// If vget and vset are both set to nil, then the behavior of this
// function is identical to the behavior of Open.
func OpenWith(
driver, dsn string,
migrations []Migrator,
vget GetVersion, vset SetVersion,
) (*sql.DB, error) {
if (vget == nil && vset != nil) || (vget != nil && vset == nil) {
panic("vget/vset must both be nil or both be non-nil")
}
if vget == nil {
vget = DefaultGetVersion
}
if vset == nil {
vset = DefaultSetVersion
}
db, err := sql.Open(driver, dsn)
if err != nil {
return nil, err
}
if err := (migration{db, migrations, vget, vset}).migrate(); err != nil {
return nil, err
}
return db, nil
}
type migration struct {
*sql.DB
migrations []Migrator
getVersion GetVersion
setVersion SetVersion
}
// Stmt satisfies the LimitedTx interface.
func (m migration) Stmt(stmt *sql.Stmt) *sql.Stmt {
return stmt
}
func (m migration) migrate() error {
libVersion := len(m.migrations)
dbVersion, err := m.getVersion(m)
if err != nil {
return ef("Could not get DB version: %s", err)
}
if dbVersion > libVersion {
return ef("Database version (%d) is greater than library version (%d).",
dbVersion, libVersion)
}
if dbVersion == libVersion {
return nil
}
tx, err := m.Begin()
if err != nil {
return ef("Could not start transaction: %s", err)
}
for i := dbVersion; i < libVersion; i++ {
if err := m.migrations[i](tx); err != nil {
if err2 := tx.Rollback(); err2 != nil {
return ef(
"When migrating from %d to %d, got error '%s' and "+
"got error '%s' after trying to rollback.",
i, i+1, err, err2)
}
return ef(
"When migrating from %d to %d, got error '%s' and "+
"successfully rolled back.", i, i+1, err)
}
}
if err := m.setVersion(tx, libVersion); err != nil {
if err2 := tx.Rollback(); err2 != nil {
return ef(
"When trying to set version to %d (from %d), got error '%s' "+
"and got error '%s' after trying to rollback.",
libVersion, dbVersion, err, err2)
}
return ef(
"When trying to set version to %d (from %d), got error '%s' "+
"and successfully rolled back.",
libVersion, dbVersion, err)
}
if err := tx.Commit(); err != nil {
return ef("Error committing migration from %d to %d: %s",
dbVersion, libVersion, err)
}
return nil
}
func defaultGetVersion(tx LimitedTx) (int, error) {
v, err := getVersion(tx)
if err != nil {
if err := createVersionTable(tx); err != nil {
return 0, err
}
return getVersion(tx)
}
return v, nil
}
func defaultSetVersion(tx LimitedTx, version int) error {
if err := setVersion(tx, version); err != nil {
if err := createVersionTable(tx); err != nil {
return err
}
return setVersion(tx, version)
}
return nil
}
func getVersion(tx LimitedTx) (int, error) {
var version int
r := tx.QueryRow("SELECT version FROM migration_version")
if err := r.Scan(&version); err != nil {
return 0, err
}
return version, nil
}
func setVersion(tx LimitedTx, version int) error {
_, err := tx.Exec("UPDATE migration_version SET version = $1", version)
return err
}
func createVersionTable(tx LimitedTx) error {
_, err := tx.Exec(`
CREATE TABLE migration_version (
version INTEGER
);
INSERT INTO migration_version (version) VALUES (0)`)
return err
}

View file

@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View file

@ -0,0 +1,5 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

View file

@ -0,0 +1,12 @@
language: go
go:
- 1.1
- 1.2
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View file

@ -0,0 +1,3 @@
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)

View file

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View file

@ -0,0 +1,19 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View file

@ -0,0 +1,220 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/mojombo/toml
Compatible with TOML version
[v0.2.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.2.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build status](https://api.travis-ci.org/BurntSushi/toml.png)](https://travis-ci.org/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View file

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View file

@ -0,0 +1,14 @@
# Implements the TOML test suite interface
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for my
[toml parser written in Go](https://github.com/BurntSushi/toml).
In particular, it maps TOML data on `stdin` to a JSON format on `stdout`.
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View file

@ -0,0 +1,90 @@
// Command toml-test-decoder satisfies the toml-test interface for testing
// TOML decoders. Namely, it accepts TOML on stdin and outputs JSON on stdout.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path"
"time"
"github.com/drone/drone/Godeps/_workspace/src/github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < toml-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if _, err := toml.DecodeReader(os.Stdin, &tmp); err != nil {
log.Fatalf("Error decoding TOML: %s", err)
}
typedTmp := translate(tmp)
if err := json.NewEncoder(os.Stdout).Encode(typedTmp); err != nil {
log.Fatalf("Error encoding JSON: %s", err)
}
}
func translate(tomlData interface{}) interface{} {
switch orig := tomlData.(type) {
case map[string]interface{}:
typed := make(map[string]interface{}, len(orig))
for k, v := range orig {
typed[k] = translate(v)
}
return typed
case []map[string]interface{}:
typed := make([]map[string]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v).(map[string]interface{})
}
return typed
case []interface{}:
typed := make([]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v)
}
// We don't really need to tag arrays, but let's be future proof.
// (If TOML ever supports tuples, we'll need this.)
return tag("array", typed)
case time.Time:
return tag("datetime", orig.Format("2006-01-02T15:04:05Z"))
case bool:
return tag("bool", fmt.Sprintf("%v", orig))
case int64:
return tag("integer", fmt.Sprintf("%d", orig))
case float64:
return tag("float", fmt.Sprintf("%v", orig))
case string:
return tag("string", orig)
}
panic(fmt.Sprintf("Unknown type: %T", tomlData))
}
func tag(typeName string, data interface{}) map[string]interface{} {
return map[string]interface{}{
"type": typeName,
"value": data,
}
}

View file

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View file

@ -0,0 +1,14 @@
# Implements the TOML test suite interface for TOML encoders
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for the
[TOML encoder](https://github.com/BurntSushi/toml).
In particular, it maps JSON data on `stdin` to a TOML format on `stdout`.
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View file

@ -0,0 +1,131 @@
// Command toml-test-encoder satisfies the toml-test interface for testing
// TOML encoders. Namely, it accepts JSON on stdin and outputs TOML on stdout.
package main
import (
"encoding/json"
"flag"
"log"
"os"
"path"
"strconv"
"time"
"github.com/drone/drone/Godeps/_workspace/src/github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < json-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if err := json.NewDecoder(os.Stdin).Decode(&tmp); err != nil {
log.Fatalf("Error decoding JSON: %s", err)
}
tomlData := translate(tmp)
if err := toml.NewEncoder(os.Stdout).Encode(tomlData); err != nil {
log.Fatalf("Error encoding TOML: %s", err)
}
}
func translate(typedJson interface{}) interface{} {
switch v := typedJson.(type) {
case map[string]interface{}:
if len(v) == 2 && in("type", v) && in("value", v) {
return untag(v)
}
m := make(map[string]interface{}, len(v))
for k, v2 := range v {
m[k] = translate(v2)
}
return m
case []interface{}:
tabArray := make([]map[string]interface{}, len(v))
for i := range v {
if m, ok := translate(v[i]).(map[string]interface{}); ok {
tabArray[i] = m
} else {
log.Fatalf("JSON arrays may only contain objects. This " +
"corresponds to only tables being allowed in " +
"TOML table arrays.")
}
}
return tabArray
}
log.Fatalf("Unrecognized JSON format '%T'.", typedJson)
panic("unreachable")
}
func untag(typed map[string]interface{}) interface{} {
t := typed["type"].(string)
v := typed["value"]
switch t {
case "string":
return v.(string)
case "integer":
v := v.(string)
n, err := strconv.Atoi(v)
if err != nil {
log.Fatalf("Could not parse '%s' as integer: %s", v, err)
}
return n
case "float":
v := v.(string)
f, err := strconv.ParseFloat(v, 64)
if err != nil {
log.Fatalf("Could not parse '%s' as float64: %s", v, err)
}
return f
case "datetime":
v := v.(string)
t, err := time.Parse("2006-01-02T15:04:05Z", v)
if err != nil {
log.Fatalf("Could not parse '%s' as a datetime: %s", v, err)
}
return t
case "bool":
v := v.(string)
switch v {
case "true":
return true
case "false":
return false
}
log.Fatalf("Could not parse '%s' as a boolean.", v)
case "array":
v := v.([]interface{})
array := make([]interface{}, len(v))
for i := range v {
if m, ok := v[i].(map[string]interface{}); ok {
array[i] = untag(m)
} else {
log.Fatalf("Arrays may only contain other arrays or "+
"primitive values, but found a '%T'.", m)
}
}
return array
}
log.Fatalf("Unrecognized tag type '%s'.", t)
panic("unreachable")
}
func in(key string, m map[string]interface{}) bool {
_, ok := m[key]
return ok
}

View file

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View file

@ -0,0 +1,22 @@
# TOML Validator
If Go is installed, it's simple to try it out:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
You can see the types of every key in a TOML file with:
```bash
tomlv -types some-toml-file.toml
```
At the moment, only one error message is reported at a time. Error messages
include line numbers. No output means that the files given are valid TOML, or
there is a bug in `tomlv`.
Compatible with TOML version
[v0.1.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.1.0.md)

View file

@ -0,0 +1,61 @@
// Command tomlv validates TOML documents and prints each key's type.
package main
import (
"flag"
"fmt"
"log"
"os"
"path"
"strings"
"text/tabwriter"
"github.com/drone/drone/Godeps/_workspace/src/github.com/BurntSushi/toml"
)
var (
flagTypes = false
)
func init() {
log.SetFlags(0)
flag.BoolVar(&flagTypes, "types", flagTypes,
"When set, the types of every defined key will be shown.")
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s toml-file [ toml-file ... ]\n",
path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() < 1 {
flag.Usage()
}
for _, f := range flag.Args() {
var tmp interface{}
md, err := toml.DecodeFile(f, &tmp)
if err != nil {
log.Fatalf("Error in '%s': %s", f, err)
}
if flagTypes {
printTypes(md)
}
}
}
func printTypes(md toml.MetaData) {
tabw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
for _, key := range md.Keys() {
fmt.Fprintf(tabw, "%s%s\t%s\n",
strings.Repeat(" ", len(key)-1), key, md.Type(key...))
}
tabw.Flush()
}

View file

@ -0,0 +1,492 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
var e = fmt.Errorf
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, rvalue(v))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return mismatch(rv, "map", mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s",
rv.Type().String(), f.name, err)
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen))
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanAddr() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
}

View file

@ -0,0 +1,122 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
}
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

View file

@ -0,0 +1,950 @@
package toml
import (
"fmt"
"log"
"reflect"
"testing"
"time"
)
func init() {
log.SetFlags(0)
}
func TestDecodeSimple(t *testing.T) {
var testSimple = `
age = 250
andrew = "gallant"
kait = "brady"
now = 1987-07-05T05:45:00Z
yesOrNo = true
pi = 3.14
colors = [
["red", "green", "blue"],
["cyan", "magenta", "yellow", "black"],
]
[My.Cats]
plato = "cat 1"
cauchy = "cat 2"
`
type cats struct {
Plato string
Cauchy string
}
type simple struct {
Age int
Colors [][]string
Pi float64
YesOrNo bool
Now time.Time
Andrew string
Kait string
My map[string]cats
}
var val simple
_, err := Decode(testSimple, &val)
if err != nil {
t.Fatal(err)
}
now, err := time.Parse("2006-01-02T15:04:05", "1987-07-05T05:45:00")
if err != nil {
panic(err)
}
var answer = simple{
Age: 250,
Andrew: "gallant",
Kait: "brady",
Now: now,
YesOrNo: true,
Pi: 3.14,
Colors: [][]string{
{"red", "green", "blue"},
{"cyan", "magenta", "yellow", "black"},
},
My: map[string]cats{
"Cats": cats{Plato: "cat 1", Cauchy: "cat 2"},
},
}
if !reflect.DeepEqual(val, answer) {
t.Fatalf("Expected\n-----\n%#v\n-----\nbut got\n-----\n%#v\n",
answer, val)
}
}
func TestDecodeEmbedded(t *testing.T) {
type Dog struct{ Name string }
type Age int
tests := map[string]struct {
input string
decodeInto interface{}
wantDecoded interface{}
}{
"embedded struct": {
input: `Name = "milton"`,
decodeInto: &struct{ Dog }{},
wantDecoded: &struct{ Dog }{Dog{"milton"}},
},
"embedded non-nil pointer to struct": {
input: `Name = "milton"`,
decodeInto: &struct{ *Dog }{},
wantDecoded: &struct{ *Dog }{&Dog{"milton"}},
},
"embedded nil pointer to struct": {
input: ``,
decodeInto: &struct{ *Dog }{},
wantDecoded: &struct{ *Dog }{nil},
},
"embedded int": {
input: `Age = -5`,
decodeInto: &struct{ Age }{},
wantDecoded: &struct{ Age }{-5},
},
}
for label, test := range tests {
_, err := Decode(test.input, test.decodeInto)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(test.wantDecoded, test.decodeInto) {
t.Errorf("%s: want decoded == %+v, got %+v",
label, test.wantDecoded, test.decodeInto)
}
}
}
func TestTableArrays(t *testing.T) {
var tomlTableArrays = `
[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
type Song struct {
Name string
}
type Album struct {
Name string
Songs []Song
}
type Music struct {
Albums []Album
}
expected := Music{[]Album{
{"Born to Run", []Song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA", []Song{{"Glory Days"}, {"Dancing in the Dark"}}},
}}
var got Music
if _, err := Decode(tomlTableArrays, &got); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
// Case insensitive matching tests.
// A bit more comprehensive than needed given the current implementation,
// but implementations change.
// Probably still missing demonstrations of some ugly corner cases regarding
// case insensitive matching and multiple fields.
func TestCase(t *testing.T) {
var caseToml = `
tOpString = "string"
tOpInt = 1
tOpFloat = 1.1
tOpBool = true
tOpdate = 2006-01-02T15:04:05Z
tOparray = [ "array" ]
Match = "i should be in Match only"
MatcH = "i should be in MatcH only"
once = "just once"
[nEst.eD]
nEstedString = "another string"
`
type InsensitiveEd struct {
NestedString string
}
type InsensitiveNest struct {
Ed InsensitiveEd
}
type Insensitive struct {
TopString string
TopInt int
TopFloat float64
TopBool bool
TopDate time.Time
TopArray []string
Match string
MatcH string
Once string
OncE string
Nest InsensitiveNest
}
tme, err := time.Parse(time.RFC3339, time.RFC3339[:len(time.RFC3339)-5])
if err != nil {
panic(err)
}
expected := Insensitive{
TopString: "string",
TopInt: 1,
TopFloat: 1.1,
TopBool: true,
TopDate: tme,
TopArray: []string{"array"},
MatcH: "i should be in MatcH only",
Match: "i should be in Match only",
Once: "just once",
OncE: "",
Nest: InsensitiveNest{
Ed: InsensitiveEd{NestedString: "another string"},
},
}
var got Insensitive
if _, err := Decode(caseToml, &got); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
func TestPointers(t *testing.T) {
type Object struct {
Type string
Description string
}
type Dict struct {
NamedObject map[string]*Object
BaseObject *Object
Strptr *string
Strptrs []*string
}
s1, s2, s3 := "blah", "abc", "def"
expected := &Dict{
Strptr: &s1,
Strptrs: []*string{&s2, &s3},
NamedObject: map[string]*Object{
"foo": {"FOO", "fooooo!!!"},
"bar": {"BAR", "ba-ba-ba-ba-barrrr!!!"},
},
BaseObject: &Object{"BASE", "da base"},
}
ex1 := `
Strptr = "blah"
Strptrs = ["abc", "def"]
[NamedObject.foo]
Type = "FOO"
Description = "fooooo!!!"
[NamedObject.bar]
Type = "BAR"
Description = "ba-ba-ba-ba-barrrr!!!"
[BaseObject]
Type = "BASE"
Description = "da base"
`
dict := new(Dict)
_, err := Decode(ex1, dict)
if err != nil {
t.Errorf("Decode error: %v", err)
}
if !reflect.DeepEqual(expected, dict) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, dict)
}
}
type sphere struct {
Center [3]float64
Radius float64
}
func TestDecodeSimpleArray(t *testing.T) {
var s1 sphere
if _, err := Decode(`center = [0.0, 1.5, 0.0]`, &s1); err != nil {
t.Fatal(err)
}
}
func TestDecodeArrayWrongSize(t *testing.T) {
var s1 sphere
if _, err := Decode(`center = [0.1, 2.3]`, &s1); err == nil {
t.Fatal("Expected array type mismatch error")
}
}
func TestDecodeLargeIntoSmallInt(t *testing.T) {
type table struct {
Value int8
}
var tab table
if _, err := Decode(`value = 500`, &tab); err == nil {
t.Fatal("Expected integer out-of-bounds error.")
}
}
func TestDecodeSizedInts(t *testing.T) {
type table struct {
U8 uint8
U16 uint16
U32 uint32
U64 uint64
U uint
I8 int8
I16 int16
I32 int32
I64 int64
I int
}
answer := table{1, 1, 1, 1, 1, -1, -1, -1, -1, -1}
toml := `
u8 = 1
u16 = 1
u32 = 1
u64 = 1
u = 1
i8 = -1
i16 = -1
i32 = -1
i64 = -1
i = -1
`
var tab table
if _, err := Decode(toml, &tab); err != nil {
t.Fatal(err.Error())
}
if answer != tab {
t.Fatalf("Expected %#v but got %#v", answer, tab)
}
}
func TestUnmarshaler(t *testing.T) {
var tomlBlob = `
[dishes.hamboogie]
name = "Hamboogie with fries"
price = 10.99
[[dishes.hamboogie.ingredients]]
name = "Bread Bun"
[[dishes.hamboogie.ingredients]]
name = "Lettuce"
[[dishes.hamboogie.ingredients]]
name = "Real Beef Patty"
[[dishes.hamboogie.ingredients]]
name = "Tomato"
[dishes.eggsalad]
name = "Egg Salad with rice"
price = 3.99
[[dishes.eggsalad.ingredients]]
name = "Egg"
[[dishes.eggsalad.ingredients]]
name = "Mayo"
[[dishes.eggsalad.ingredients]]
name = "Rice"
`
m := &menu{}
if _, err := Decode(tomlBlob, m); err != nil {
log.Fatal(err)
}
if len(m.Dishes) != 2 {
t.Log("two dishes should be loaded with UnmarshalTOML()")
t.Errorf("expected %d but got %d", 2, len(m.Dishes))
}
eggSalad := m.Dishes["eggsalad"]
if _, ok := interface{}(eggSalad).(dish); !ok {
t.Errorf("expected a dish")
}
if eggSalad.Name != "Egg Salad with rice" {
t.Errorf("expected the dish to be named 'Egg Salad with rice'")
}
if len(eggSalad.Ingredients) != 3 {
t.Log("dish should be loaded with UnmarshalTOML()")
t.Errorf("expected %d but got %d", 3, len(eggSalad.Ingredients))
}
found := false
for _, i := range eggSalad.Ingredients {
if i.Name == "Rice" {
found = true
break
}
}
if !found {
t.Error("Rice was not loaded in UnmarshalTOML()")
}
// test on a value - must be passed as *
o := menu{}
if _, err := Decode(tomlBlob, &o); err != nil {
log.Fatal(err)
}
}
type menu struct {
Dishes map[string]dish
}
func (m *menu) UnmarshalTOML(p interface{}) error {
m.Dishes = make(map[string]dish)
data, _ := p.(map[string]interface{})
dishes := data["dishes"].(map[string]interface{})
for n, v := range dishes {
if d, ok := v.(map[string]interface{}); ok {
nd := dish{}
nd.UnmarshalTOML(d)
m.Dishes[n] = nd
} else {
return fmt.Errorf("not a dish")
}
}
return nil
}
type dish struct {
Name string
Price float32
Ingredients []ingredient
}
func (d *dish) UnmarshalTOML(p interface{}) error {
data, _ := p.(map[string]interface{})
d.Name, _ = data["name"].(string)
d.Price, _ = data["price"].(float32)
ingredients, _ := data["ingredients"].([]map[string]interface{})
for _, e := range ingredients {
n, _ := interface{}(e).(map[string]interface{})
name, _ := n["name"].(string)
i := ingredient{name}
d.Ingredients = append(d.Ingredients, i)
}
return nil
}
type ingredient struct {
Name string
}
func ExampleMetaData_PrimitiveDecode() {
var md MetaData
var err error
var tomlBlob = `
ranking = ["Springsteen", "J Geils"]
[bands.Springsteen]
started = 1973
albums = ["Greetings", "WIESS", "Born to Run", "Darkness"]
[bands."J Geils"]
started = 1970
albums = ["The J. Geils Band", "Full House", "Blow Your Face Out"]
`
type band struct {
Started int
Albums []string
}
type classics struct {
Ranking []string
Bands map[string]Primitive
}
// Do the initial decode. Reflection is delayed on Primitive values.
var music classics
if md, err = Decode(tomlBlob, &music); err != nil {
log.Fatal(err)
}
// MetaData still includes information on Primitive values.
fmt.Printf("Is `bands.Springsteen` defined? %v\n",
md.IsDefined("bands", "Springsteen"))
// Decode primitive data into Go values.
for _, artist := range music.Ranking {
// A band is a primitive value, so we need to decode it to get a
// real `band` value.
primValue := music.Bands[artist]
var aBand band
if err = md.PrimitiveDecode(primValue, &aBand); err != nil {
log.Fatal(err)
}
fmt.Printf("%s started in %d.\n", artist, aBand.Started)
}
// Check to see if there were any fields left undecoded.
// Note that this won't be empty before decoding the Primitive value!
fmt.Printf("Undecoded: %q\n", md.Undecoded())
// Output:
// Is `bands.Springsteen` defined? true
// Springsteen started in 1973.
// J Geils started in 1970.
// Undecoded: []
}
func ExampleDecode() {
var tomlBlob = `
# Some comments.
[alpha]
ip = "10.0.0.1"
[alpha.config]
Ports = [ 8001, 8002 ]
Location = "Toronto"
Created = 1987-07-05T05:45:00Z
[beta]
ip = "10.0.0.2"
[beta.config]
Ports = [ 9001, 9002 ]
Location = "New Jersey"
Created = 1887-01-05T05:55:00Z
`
type serverConfig struct {
Ports []int
Location string
Created time.Time
}
type server struct {
IP string `toml:"ip"`
Config serverConfig `toml:"config"`
}
type servers map[string]server
var config servers
if _, err := Decode(tomlBlob, &config); err != nil {
log.Fatal(err)
}
for _, name := range []string{"alpha", "beta"} {
s := config[name]
fmt.Printf("Server: %s (ip: %s) in %s created on %s\n",
name, s.IP, s.Config.Location,
s.Config.Created.Format("2006-01-02"))
fmt.Printf("Ports: %v\n", s.Config.Ports)
}
// Output:
// Server: alpha (ip: 10.0.0.1) in Toronto created on 1987-07-05
// Ports: [8001 8002]
// Server: beta (ip: 10.0.0.2) in New Jersey created on 1887-01-05
// Ports: [9001 9002]
}
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
// Example Unmarshaler shows how to decode TOML strings into your own
// custom data type.
func Example_unmarshaler() {
blob := `
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
`
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
// Code to implement the TextUnmarshaler interface for `duration`:
//
// type duration struct {
// time.Duration
// }
//
// func (d *duration) UnmarshalText(text []byte) error {
// var err error
// d.Duration, err = time.ParseDuration(string(text))
// return err
// }
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
// Output:
// Thunder Road (4m49s)
// Stairway to Heaven (8m3s)
}
// Example StrictDecoding shows how to detect whether there are keys in the
// TOML document that weren't decoded into the value given. This is useful
// for returning an error to the user if they've included extraneous fields
// in their configuration.
func Example_strictDecoding() {
var blob = `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
type config struct {
Key1 string
Key3 string
}
var conf config
md, err := Decode(blob, &conf)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Undecoded keys: %q\n", md.Undecoded())
// Output:
// Undecoded keys: ["key2"]
}
// Example UnmarshalTOML shows how to implement a struct type that knows how to
// unmarshal itself. The struct must take full responsibility for mapping the
// values passed into the struct. The method may be used with interfaces in a
// struct in cases where the actual type is not known until the data is
// examined.
func Example_unmarshalTOML() {
var blob = `
[[parts]]
type = "valve"
id = "valve-1"
size = 1.2
rating = 4
[[parts]]
type = "valve"
id = "valve-2"
size = 2.1
rating = 5
[[parts]]
type = "pipe"
id = "pipe-1"
length = 2.1
diameter = 12
[[parts]]
type = "cable"
id = "cable-1"
length = 12
rating = 3.1
`
o := &order{}
err := Unmarshal([]byte(blob), o)
if err != nil {
log.Fatal(err)
}
fmt.Println(len(o.parts))
for _, part := range o.parts {
fmt.Println(part.Name())
}
// Code to implement UmarshalJSON.
// type order struct {
// // NOTE `order.parts` is a private slice of type `part` which is an
// // interface and may only be loaded from toml using the
// // UnmarshalTOML() method of the Umarshaler interface.
// parts parts
// }
// func (o *order) UnmarshalTOML(data interface{}) error {
// // NOTE the example below contains detailed type casting to show how
// // the 'data' is retrieved. In operational use, a type cast wrapper
// // may be prefered e.g.
// //
// // func AsMap(v interface{}) (map[string]interface{}, error) {
// // return v.(map[string]interface{})
// // }
// //
// // resulting in:
// // d, _ := AsMap(data)
// //
// d, _ := data.(map[string]interface{})
// parts, _ := d["parts"].([]map[string]interface{})
// for _, p := range parts {
// typ, _ := p["type"].(string)
// id, _ := p["id"].(string)
// // detect the type of part and handle each case
// switch p["type"] {
// case "valve":
// size := float32(p["size"].(float64))
// rating := int(p["rating"].(int64))
// valve := &valve{
// Type: typ,
// ID: id,
// Size: size,
// Rating: rating,
// }
// o.parts = append(o.parts, valve)
// case "pipe":
// length := float32(p["length"].(float64))
// diameter := int(p["diameter"].(int64))
// pipe := &pipe{
// Type: typ,
// ID: id,
// Length: length,
// Diameter: diameter,
// }
// o.parts = append(o.parts, pipe)
// case "cable":
// length := int(p["length"].(int64))
// rating := float32(p["rating"].(float64))
// cable := &cable{
// Type: typ,
// ID: id,
// Length: length,
// Rating: rating,
// }
// o.parts = append(o.parts, cable)
// }
// }
// return nil
// }
// type parts []part
// type part interface {
// Name() string
// }
// type valve struct {
// Type string
// ID string
// Size float32
// Rating int
// }
// func (v *valve) Name() string {
// return fmt.Sprintf("VALVE: %s", v.ID)
// }
// type pipe struct {
// Type string
// ID string
// Length float32
// Diameter int
// }
// func (p *pipe) Name() string {
// return fmt.Sprintf("PIPE: %s", p.ID)
// }
// type cable struct {
// Type string
// ID string
// Length int
// Rating float32
// }
// func (c *cable) Name() string {
// return fmt.Sprintf("CABLE: %s", c.ID)
// }
// Output:
// 4
// VALVE: valve-1
// VALVE: valve-2
// PIPE: pipe-1
// CABLE: cable-1
}
type order struct {
// NOTE `order.parts` is a private slice of type `part` which is an
// interface and may only be loaded from toml using the UnmarshalTOML()
// method of the Umarshaler interface.
parts parts
}
func (o *order) UnmarshalTOML(data interface{}) error {
// NOTE the example below contains detailed type casting to show how
// the 'data' is retrieved. In operational use, a type cast wrapper
// may be prefered e.g.
//
// func AsMap(v interface{}) (map[string]interface{}, error) {
// return v.(map[string]interface{})
// }
//
// resulting in:
// d, _ := AsMap(data)
//
d, _ := data.(map[string]interface{})
parts, _ := d["parts"].([]map[string]interface{})
for _, p := range parts {
typ, _ := p["type"].(string)
id, _ := p["id"].(string)
// detect the type of part and handle each case
switch p["type"] {
case "valve":
size := float32(p["size"].(float64))
rating := int(p["rating"].(int64))
valve := &valve{
Type: typ,
ID: id,
Size: size,
Rating: rating,
}
o.parts = append(o.parts, valve)
case "pipe":
length := float32(p["length"].(float64))
diameter := int(p["diameter"].(int64))
pipe := &pipe{
Type: typ,
ID: id,
Length: length,
Diameter: diameter,
}
o.parts = append(o.parts, pipe)
case "cable":
length := int(p["length"].(int64))
rating := float32(p["rating"].(float64))
cable := &cable{
Type: typ,
ID: id,
Length: length,
Rating: rating,
}
o.parts = append(o.parts, cable)
}
}
return nil
}
type parts []part
type part interface {
Name() string
}
type valve struct {
Type string
ID string
Size float32
Rating int
}
func (v *valve) Name() string {
return fmt.Sprintf("VALVE: %s", v.ID)
}
type pipe struct {
Type string
ID string
Length float32
Diameter int
}
func (p *pipe) Name() string {
return fmt.Sprintf("PIPE: %s", p.ID)
}
type cable struct {
Type string
ID string
Length int
Rating float32
}
func (c *cable) Name() string {
return fmt.Sprintf("CABLE: %s", c.ID)
}

View file

@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/mojombo/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View file

@ -0,0 +1,496 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types")
errArrayNilElement = errors.New(
"can't encode array with nil element")
errNonString = errors.New(
"can't encode a map with non-string key type")
errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"TOML array element can't contain a table")
errNoKey = errors.New(
"top-level values must be a Go map or struct")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("Unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("Unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra new line between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexporded fields
if f.PkgPath != "" {
continue
}
frv := rv.Field(i)
if f.Anonymous {
frv := eindirect(frv)
t := frv.Type()
if t.Kind() != reflect.Struct {
encPanic(errAnonNonStruct)
}
addFields(t, frv, f.Index)
} else if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
keyName := sft.Tag.Get("toml")
if keyName == "-" {
continue
}
if keyName == "" {
keyName = sft.Name
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
} else {
return tomlArray
}
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View file

@ -0,0 +1,506 @@
package toml
import (
"bytes"
"fmt"
"log"
"net"
"testing"
"time"
)
func TestEncodeRoundTrip(t *testing.T) {
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time
Ipaddress net.IP
}
var inputs = Config{
13,
[]string{"one", "two", "three"},
3.145,
[]int{11, 2, 3, 4},
time.Now(),
net.ParseIP("192.168.59.254"),
}
var firstBuffer bytes.Buffer
e := NewEncoder(&firstBuffer)
err := e.Encode(inputs)
if err != nil {
t.Fatal(err)
}
var outputs Config
if _, err := Decode(firstBuffer.String(), &outputs); err != nil {
log.Printf("Could not decode:\n-----\n%s\n-----\n",
firstBuffer.String())
t.Fatal(err)
}
// could test each value individually, but I'm lazy
var secondBuffer bytes.Buffer
e2 := NewEncoder(&secondBuffer)
err = e2.Encode(outputs)
if err != nil {
t.Fatal(err)
}
if firstBuffer.String() != secondBuffer.String() {
t.Error(
firstBuffer.String(),
"\n\n is not identical to\n\n",
secondBuffer.String())
}
}
// XXX(burntsushi)
// I think these tests probably should be removed. They are good, but they
// ought to be obsolete by toml-test.
func TestEncode(t *testing.T) {
type Embedded struct {
Int int `toml:"_int"`
}
type NonStruct int
date := time.Date(2014, 5, 11, 20, 30, 40, 0, time.FixedZone("IST", 3600))
dateStr := "2014-05-11T19:30:40Z"
tests := map[string]struct {
input interface{}
wantOutput string
wantError error
}{
"bool field": {
input: struct {
BoolTrue bool
BoolFalse bool
}{true, false},
wantOutput: "BoolTrue = true\nBoolFalse = false\n",
},
"int fields": {
input: struct {
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
}{1, 2, 3, 4, 5},
wantOutput: "Int = 1\nInt8 = 2\nInt16 = 3\nInt32 = 4\nInt64 = 5\n",
},
"uint fields": {
input: struct {
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
}{1, 2, 3, 4, 5},
wantOutput: "Uint = 1\nUint8 = 2\nUint16 = 3\nUint32 = 4" +
"\nUint64 = 5\n",
},
"float fields": {
input: struct {
Float32 float32
Float64 float64
}{1.5, 2.5},
wantOutput: "Float32 = 1.5\nFloat64 = 2.5\n",
},
"string field": {
input: struct{ String string }{"foo"},
wantOutput: "String = \"foo\"\n",
},
"string field and unexported field": {
input: struct {
String string
unexported int
}{"foo", 0},
wantOutput: "String = \"foo\"\n",
},
"datetime field in UTC": {
input: struct{ Date time.Time }{date},
wantOutput: fmt.Sprintf("Date = %s\n", dateStr),
},
"datetime field as primitive": {
// Using a map here to fail if isStructOrMap() returns true for
// time.Time.
input: map[string]interface{}{
"Date": date,
"Int": 1,
},
wantOutput: fmt.Sprintf("Date = %s\nInt = 1\n", dateStr),
},
"array fields": {
input: struct {
IntArray0 [0]int
IntArray3 [3]int
}{[0]int{}, [3]int{1, 2, 3}},
wantOutput: "IntArray0 = []\nIntArray3 = [1, 2, 3]\n",
},
"slice fields": {
input: struct{ IntSliceNil, IntSlice0, IntSlice3 []int }{
nil, []int{}, []int{1, 2, 3},
},
wantOutput: "IntSlice0 = []\nIntSlice3 = [1, 2, 3]\n",
},
"datetime slices": {
input: struct{ DatetimeSlice []time.Time }{
[]time.Time{date, date},
},
wantOutput: fmt.Sprintf("DatetimeSlice = [%s, %s]\n",
dateStr, dateStr),
},
"nested arrays and slices": {
input: struct {
SliceOfArrays [][2]int
ArrayOfSlices [2][]int
SliceOfArraysOfSlices [][2][]int
ArrayOfSlicesOfArrays [2][][2]int
SliceOfMixedArrays [][2]interface{}
ArrayOfMixedSlices [2][]interface{}
}{
[][2]int{{1, 2}, {3, 4}},
[2][]int{{1, 2}, {3, 4}},
[][2][]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[2][][2]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[][2]interface{}{
{1, 2}, {"a", "b"},
},
[2][]interface{}{
{1, 2}, {"a", "b"},
},
},
wantOutput: `SliceOfArrays = [[1, 2], [3, 4]]
ArrayOfSlices = [[1, 2], [3, 4]]
SliceOfArraysOfSlices = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
ArrayOfSlicesOfArrays = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
SliceOfMixedArrays = [[1, 2], ["a", "b"]]
ArrayOfMixedSlices = [[1, 2], ["a", "b"]]
`,
},
"empty slice": {
input: struct{ Empty []interface{} }{[]interface{}{}},
wantOutput: "Empty = []\n",
},
"(error) slice with element type mismatch (string and integer)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, "a"}},
wantError: errArrayMixedElementTypes,
},
"(error) slice with element type mismatch (integer and float)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, 2.5}},
wantError: errArrayMixedElementTypes,
},
"slice with elems of differing Go types, same TOML types": {
input: struct {
MixedInts []interface{}
MixedFloats []interface{}
}{
[]interface{}{
int(1), int8(2), int16(3), int32(4), int64(5),
uint(1), uint8(2), uint16(3), uint32(4), uint64(5),
},
[]interface{}{float32(1.5), float64(2.5)},
},
wantOutput: "MixedInts = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]\n" +
"MixedFloats = [1.5, 2.5]\n",
},
"(error) slice w/ element type mismatch (one is nested array)": {
input: struct{ Mixed []interface{} }{
[]interface{}{1, []interface{}{2}},
},
wantError: errArrayMixedElementTypes,
},
"(error) slice with 1 nil element": {
input: struct{ NilElement1 []interface{} }{[]interface{}{nil}},
wantError: errArrayNilElement,
},
"(error) slice with 1 nil element (and other non-nil elements)": {
input: struct{ NilElement []interface{} }{
[]interface{}{1, nil},
},
wantError: errArrayNilElement,
},
"simple map": {
input: map[string]int{"a": 1, "b": 2},
wantOutput: "a = 1\nb = 2\n",
},
"map with interface{} value type": {
input: map[string]interface{}{"a": 1, "b": "c"},
wantOutput: "a = 1\nb = \"c\"\n",
},
"map with interface{} value type, some of which are structs": {
input: map[string]interface{}{
"a": struct{ Int int }{2},
"b": 1,
},
wantOutput: "b = 1\n\n[a]\n Int = 2\n",
},
"nested map": {
input: map[string]map[string]int{
"a": {"b": 1},
"c": {"d": 2},
},
wantOutput: "[a]\n b = 1\n\n[c]\n d = 2\n",
},
"nested struct": {
input: struct{ Struct struct{ Int int } }{
struct{ Int int }{1},
},
wantOutput: "[Struct]\n Int = 1\n",
},
"nested struct and non-struct field": {
input: struct {
Struct struct{ Int int }
Bool bool
}{struct{ Int int }{1}, true},
wantOutput: "Bool = true\n\n[Struct]\n Int = 1\n",
},
"2 nested structs": {
input: struct{ Struct1, Struct2 struct{ Int int } }{
struct{ Int int }{1}, struct{ Int int }{2},
},
wantOutput: "[Struct1]\n Int = 1\n\n[Struct2]\n Int = 2\n",
},
"deeply nested structs": {
input: struct {
Struct1, Struct2 struct{ Struct3 *struct{ Int int } }
}{
struct{ Struct3 *struct{ Int int } }{&struct{ Int int }{1}},
struct{ Struct3 *struct{ Int int } }{nil},
},
wantOutput: "[Struct1]\n [Struct1.Struct3]\n Int = 1" +
"\n\n[Struct2]\n",
},
"nested struct with nil struct elem": {
input: struct {
Struct struct{ Inner *struct{ Int int } }
}{
struct{ Inner *struct{ Int int } }{nil},
},
wantOutput: "[Struct]\n",
},
"nested struct with no fields": {
input: struct {
Struct struct{ Inner struct{} }
}{
struct{ Inner struct{} }{struct{}{}},
},
wantOutput: "[Struct]\n [Struct.Inner]\n",
},
"struct with tags": {
input: struct {
Struct struct {
Int int `toml:"_int"`
} `toml:"_struct"`
Bool bool `toml:"_bool"`
}{
struct {
Int int `toml:"_int"`
}{1}, true,
},
wantOutput: "_bool = true\n\n[_struct]\n _int = 1\n",
},
"embedded struct": {
input: struct{ Embedded }{Embedded{1}},
wantOutput: "_int = 1\n",
},
"embedded *struct": {
input: struct{ *Embedded }{&Embedded{1}},
wantOutput: "_int = 1\n",
},
"nested embedded struct": {
input: struct {
Struct struct{ Embedded } `toml:"_struct"`
}{struct{ Embedded }{Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"nested embedded *struct": {
input: struct {
Struct struct{ *Embedded } `toml:"_struct"`
}{struct{ *Embedded }{&Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"array of tables": {
input: struct {
Structs []*struct{ Int int } `toml:"struct"`
}{
[]*struct{ Int int }{{1}, {3}},
},
wantOutput: "[[struct]]\n Int = 1\n\n[[struct]]\n Int = 3\n",
},
"array of tables order": {
input: map[string]interface{}{
"map": map[string]interface{}{
"zero": 5,
"arr": []map[string]int{
map[string]int{
"friend": 5,
},
},
},
},
wantOutput: "[map]\n zero = 5\n\n [[map.arr]]\n friend = 5\n",
},
"(error) top-level slice": {
input: []struct{ Int int }{{1}, {2}, {3}},
wantError: errNoKey,
},
"(error) slice of slice": {
input: struct {
Slices [][]struct{ Int int }
}{
[][]struct{ Int int }{{{1}}, {{2}}, {{3}}},
},
wantError: errArrayNoTable,
},
"(error) map no string key": {
input: map[int]string{1: ""},
wantError: errNonString,
},
"(error) anonymous non-struct": {
input: struct{ NonStruct }{5},
wantError: errAnonNonStruct,
},
"(error) empty key name": {
input: map[string]int{"": 1},
wantError: errAnything,
},
"(error) empty map name": {
input: map[string]interface{}{
"": map[string]int{"v": 1},
},
wantError: errAnything,
},
}
for label, test := range tests {
encodeExpected(t, label, test.input, test.wantOutput, test.wantError)
}
}
func TestEncodeNestedTableArrays(t *testing.T) {
type song struct {
Name string `toml:"name"`
}
type album struct {
Name string `toml:"name"`
Songs []song `toml:"songs"`
}
type springsteen struct {
Albums []album `toml:"albums"`
}
value := springsteen{
[]album{
{"Born to Run",
[]song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA",
[]song{{"Glory Days"}, {"Dancing in the Dark"}}},
},
}
expected := `[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
encodeExpected(t, "nested table arrays", value, expected, nil)
}
func TestEncodeArrayHashWithNormalHashOrder(t *testing.T) {
type Alpha struct {
V int
}
type Beta struct {
V int
}
type Conf struct {
V int
A Alpha
B []Beta
}
val := Conf{
V: 1,
A: Alpha{2},
B: []Beta{{3}},
}
expected := "V = 1\n\n[A]\n V = 2\n\n[[B]]\n V = 3\n"
encodeExpected(t, "array hash with normal hash order", val, expected, nil)
}
func encodeExpected(
t *testing.T, label string, val interface{}, wantStr string, wantErr error,
) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(val)
if err != wantErr {
if wantErr != nil {
if wantErr == errAnything && err != nil {
return
}
t.Errorf("%s: want Encode error %v, got %v", label, wantErr, err)
} else {
t.Errorf("%s: Encode failed: %s", label, err)
}
}
if err != nil {
return
}
if got := buf.String(); wantStr != got {
t.Errorf("%s: want\n-----\n%q\n-----\nbut got\n-----\n%q\n-----\n",
label, wantStr, got)
}
}
func ExampleEncoder_Encode() {
date, _ := time.Parse(time.RFC822, "14 Mar 10 18:00 UTC")
var config = map[string]interface{}{
"date": date,
"counts": []int{1, 1, 2, 3, 5, 8},
"hash": map[string]string{
"key1": "val1",
"key2": "val2",
},
}
buf := new(bytes.Buffer)
if err := NewEncoder(buf).Encode(config); err != nil {
log.Fatal(err)
}
fmt.Println(buf.String())
// Output:
// counts = [1, 1, 2, 3, 5, 8]
// date = 2010-03-14T18:00:00Z
//
// [hash]
// key1 = "val1"
// key2 = "val2"
}

View file

@ -0,0 +1,19 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View file

@ -0,0 +1,18 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

874
Godeps/_workspace/src/github.com/BurntSushi/toml/lex.go generated vendored Normal file
View file

@ -0,0 +1,874 @@
package toml
import (
"fmt"
"strings"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input + "\n",
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.pos >= len(lx.input) {
lx.width = 0
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got %q instead.", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " +
"be empty.)")
case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " +
"be empty.)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
case isWhitespace(r):
return lexTableNameStart
default:
return lexBareTableName
}
}
// lexTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+
"instead.", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.emitTrim(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emitTrim(itemText)
return lexKeyEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("Expected key separator %q, but got %q instead.",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new
// lines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case r == rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
}
return lx.errorf("Expected value but found %q instead.", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator %q.",
arrayValTerm)
case r == arrayEnd:
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator %q or an array "+
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '\\':
return lexMultilineStringEscape
case r == stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
lx.next()
return lexMultilineString
} else {
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("Invalid escape character %q. Only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+
"\\uXXXX and \\UXXXXXXXX.", r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either a (positive) integer, float or
// datetime. It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumberOrDate
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
}
}
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that
// a negative sign has already been read, but that *no* digits have been
// consumed. lexNumberStart will move to the appropriate integer or float
// states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
return lexNumber
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been
// consumed.
func lexConst(lx *lexer, s string) stateFn {
for i := range s[1:] {
if r := lx.next(); r != rune(s[i+1]) {
return lx.errorf("Expected %q, but found %q instead.", s[:i+1],
s[:i]+string(r))
}
}
return nil
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

View file

@ -0,0 +1,498 @@
package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if err != nil {
// See comment below for floats describing why we make a
// distinction between a bug and a user error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:len(s)]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
// BUG(burntsushi)
// I honestly don't understand how this works. I can't seem
// to find a way to make this fail. I figured this would fail on invalid
// UTF-8 characters like U+DCFF, but it doesn't.
if !utf8.ValidString(string(rune(hex))) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View file

@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View file

@ -0,0 +1,91 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View file

@ -0,0 +1,241 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" { // unexported
continue
}
name := sf.Tag.Get("toml")
if name == "-" {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

View file

@ -0,0 +1 @@
logrus

View file

@ -0,0 +1,8 @@
language: go
go:
- 1.2
- 1.3
- 1.4
- tip
install:
- go get -t ./...

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Simon Eskildsen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View file

@ -0,0 +1,377 @@
# Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/>&nbsp;[![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus)&nbsp;[![godoc reference](https://godoc.org/github.com/Sirupsen/logrus?status.png)][godoc]
Logrus is a structured logger for Go (golang), completely API compatible with
the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not
yet stable (pre 1.0). Logrus itself is completely stable and has been used in
many large deployments. The core API is unlikely to change much but please
version control your Logrus to make sure you aren't fetching latest `master` on
every build.**
Nicely color-coded in development (when a TTY is attached, otherwise just
plain text):
![Colored](http://i.imgur.com/PY7qMwd.png)
With `log.Formatter = new(logrus.JSONFormatter)`, for easy parsing by logstash
or Splunk:
```json
{"animal":"walrus","level":"info","msg":"A group of walrus emerges from the
ocean","size":10,"time":"2014-03-10 19:57:38.562264131 -0400 EDT"}
{"level":"warning","msg":"The group's number increased tremendously!",
"number":122,"omg":true,"time":"2014-03-10 19:57:38.562471297 -0400 EDT"}
{"animal":"walrus","level":"info","msg":"A giant walrus appears!",
"size":10,"time":"2014-03-10 19:57:38.562500591 -0400 EDT"}
{"animal":"walrus","level":"info","msg":"Tremendously sized cow enters the ocean.",
"size":9,"time":"2014-03-10 19:57:38.562527896 -0400 EDT"}
{"level":"fatal","msg":"The ice breaks!","number":100,"omg":true,
"time":"2014-03-10 19:57:38.562543128 -0400 EDT"}
```
With the default `log.Formatter = new(logrus.TextFormatter)` when a TTY is not
attached, the output is compatible with the
[logfmt](http://godoc.org/github.com/kr/logfmt) format:
```text
time="2014-04-20 15:36:23.830442383 -0400 EDT" level="info" msg="A group of walrus emerges from the ocean" animal="walrus" size=10
time="2014-04-20 15:36:23.830584199 -0400 EDT" level="warning" msg="The group's number increased tremendously!" omg=true number=122
time="2014-04-20 15:36:23.830596521 -0400 EDT" level="info" msg="A giant walrus appears!" animal="walrus" size=10
time="2014-04-20 15:36:23.830611837 -0400 EDT" level="info" msg="Tremendously sized cow enters the ocean." animal="walrus" size=9
time="2014-04-20 15:36:23.830626464 -0400 EDT" level="fatal" msg="The ice breaks!" omg=true number=100
```
#### Example
The simplest way to use Logrus is simply the package-level exported logger:
```go
package main
import (
log "github.com/Sirupsen/logrus"
)
func main() {
log.WithFields(log.Fields{
"animal": "walrus",
}).Info("A walrus appears")
}
```
Note that it's completely api-compatible with the stdlib logger, so you can
replace your `log` imports everywhere with `log "github.com/Sirupsen/logrus"`
and you'll now have the flexibility of Logrus. You can customize it all you
want:
```go
package main
import (
"os"
log "github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/hooks/airbrake"
)
func init() {
// Log as JSON instead of the default ASCII formatter.
log.SetFormatter(&log.JSONFormatter{})
// Use the Airbrake hook to report errors that have Error severity or above to
// an exception tracker. You can create custom hooks, see the Hooks section.
log.AddHook(&logrus_airbrake.AirbrakeHook{})
// Output to stderr instead of stdout, could also be a file.
log.SetOutput(os.Stderr)
// Only log the warning severity or above.
log.SetLevel(log.WarnLevel)
}
func main() {
log.WithFields(log.Fields{
"animal": "walrus",
"size": 10,
}).Info("A group of walrus emerges from the ocean")
log.WithFields(log.Fields{
"omg": true,
"number": 122,
}).Warn("The group's number increased tremendously!")
log.WithFields(log.Fields{
"omg": true,
"number": 100,
}).Fatal("The ice breaks!")
}
```
For more advanced usage such as logging to multiple locations from the same
application, you can also create an instance of the `logrus` Logger:
```go
package main
import (
"github.com/Sirupsen/logrus"
)
// Create a new instance of the logger. You can have any number of instances.
var log = logrus.New()
func main() {
// The API for setting attributes is a little different than the package level
// exported logger. See Godoc.
log.Out = os.Stderr
log.WithFields(logrus.Fields{
"animal": "walrus",
"size": 10,
}).Info("A group of walrus emerges from the ocean")
}
```
#### Fields
Logrus encourages careful, structured logging though logging fields instead of
long, unparseable error messages. For example, instead of: `log.Fatalf("Failed
to send event %s to topic %s with key %d")`, you should log the much more
discoverable:
```go
log.WithFields(log.Fields{
"event": event,
"topic": topic,
"key": key,
}).Fatal("Failed to send event")
```
We've found this API forces you to think about logging in a way that produces
much more useful logging messages. We've been in countless situations where just
a single added field to a log statement that was already there would've saved us
hours. The `WithFields` call is optional.
In general, with Logrus using any of the `printf`-family functions should be
seen as a hint you should add a field, however, you can still use the
`printf`-family functions with Logrus.
#### Hooks
You can add hooks for logging levels. For example to send errors to an exception
tracking service on `Error`, `Fatal` and `Panic`, info to StatsD or log to
multiple places simultaneously, e.g. syslog.
```go
// Not the real implementation of the Airbrake hook. Just a simple sample.
import (
log "github.com/Sirupsen/logrus"
)
func init() {
log.AddHook(new(AirbrakeHook))
}
type AirbrakeHook struct{}
// `Fire()` takes the entry that the hook is fired for. `entry.Data[]` contains
// the fields for the entry. See the Fields section of the README.
func (hook *AirbrakeHook) Fire(entry *logrus.Entry) error {
err := airbrake.Notify(entry.Data["error"].(error))
if err != nil {
log.WithFields(log.Fields{
"source": "airbrake",
"endpoint": airbrake.Endpoint,
}).Info("Failed to send error to Airbrake")
}
return nil
}
// `Levels()` returns a slice of `Levels` the hook is fired for.
func (hook *AirbrakeHook) Levels() []log.Level {
return []log.Level{
log.ErrorLevel,
log.FatalLevel,
log.PanicLevel,
}
}
```
Logrus comes with built-in hooks. Add those, or your custom hook, in `init`:
```go
import (
log "github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/hooks/airbrake"
"github.com/Sirupsen/logrus/hooks/syslog"
"log/syslog"
)
func init() {
log.AddHook(new(logrus_airbrake.AirbrakeHook))
hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "")
if err != nil {
log.Error("Unable to connect to local syslog daemon")
} else {
log.AddHook(hook)
}
}
```
* [`github.com/Sirupsen/logrus/hooks/airbrake`](https://github.com/Sirupsen/logrus/blob/master/hooks/airbrake/airbrake.go)
Send errors to an exception tracking service compatible with the Airbrake API.
Uses [`airbrake-go`](https://github.com/tobi/airbrake-go) behind the scenes.
* [`github.com/Sirupsen/logrus/hooks/papertrail`](https://github.com/Sirupsen/logrus/blob/master/hooks/papertrail/papertrail.go)
Send errors to the Papertrail hosted logging service via UDP.
* [`github.com/Sirupsen/logrus/hooks/syslog`](https://github.com/Sirupsen/logrus/blob/master/hooks/syslog/syslog.go)
Send errors to remote syslog server.
Uses standard library `log/syslog` behind the scenes.
* [`github.com/nubo/hiprus`](https://github.com/nubo/hiprus)
Send errors to a channel in hipchat.
* [`github.com/sebest/logrusly`](https://github.com/sebest/logrusly)
Send logs to Loggly (https://www.loggly.com/)
* [`github.com/johntdyer/slackrus`](https://github.com/johntdyer/slackrus)
Hook for Slack chat.
* [`github.com/wercker/journalhook`](https://github.com/wercker/journalhook).
Hook for logging to `systemd-journald`.
#### Level logging
Logrus has six logging levels: Debug, Info, Warning, Error, Fatal and Panic.
```go
log.Debug("Useful debugging information.")
log.Info("Something noteworthy happened!")
log.Warn("You should probably take a look at this.")
log.Error("Something failed but I'm not quitting.")
// Calls os.Exit(1) after logging
log.Fatal("Bye.")
// Calls panic() after logging
log.Panic("I'm bailing.")
```
You can set the logging level on a `Logger`, then it will only log entries with
that severity or anything above it:
```go
// Will log anything that is info or above (warn, error, fatal, panic). Default.
log.SetLevel(log.InfoLevel)
```
It may be useful to set `log.Level = logrus.DebugLevel` in a debug or verbose
environment if your application has that.
#### Entries
Besides the fields added with `WithField` or `WithFields` some fields are
automatically added to all logging events:
1. `time`. The timestamp when the entry was created.
2. `msg`. The logging message passed to `{Info,Warn,Error,Fatal,Panic}` after
the `AddFields` call. E.g. `Failed to send event.`
3. `level`. The logging level. E.g. `info`.
#### Environments
Logrus has no notion of environment.
If you wish for hooks and formatters to only be used in specific environments,
you should handle that yourself. For example, if your application has a global
variable `Environment`, which is a string representation of the environment you
could do:
```go
import (
log "github.com/Sirupsen/logrus"
)
init() {
// do something here to set environment depending on an environment variable
// or command-line flag
if Environment == "production" {
log.SetFormatter(logrus.JSONFormatter)
} else {
// The TextFormatter is default, you don't actually have to do this.
log.SetFormatter(logrus.TextFormatter)
}
}
```
This configuration is how `logrus` was intended to be used, but JSON in
production is mostly only useful if you do log aggregation with tools like
Splunk or Logstash.
#### Formatters
The built-in logging formatters are:
* `logrus.TextFormatter`. Logs the event in colors if stdout is a tty, otherwise
without colors.
* *Note:* to force colored output when there is no TTY, set the `ForceColors`
field to `true`. To force no colored output even if there is a TTY set the
`DisableColors` field to `true`
* `logrus.JSONFormatter`. Logs fields as JSON.
Third party logging formatters:
* [`zalgo`](https://github.com/aybabtme/logzalgo): invoking the P͉̫o̳̼̊w̖͈̰͎e̬͔̭͂r͚̼̹̲ ̫͓͉̳͈ō̠͕͖̚f̝͍̠ ͕̲̞͖͑Z̖̫̤̫ͪa͉̬͈̗l͖͎g̳̥o̰̥̅!̣͔̲̻͊̄ ̙̘̦̹̦.
You can define your formatter by implementing the `Formatter` interface,
requiring a `Format` method. `Format` takes an `*Entry`. `entry.Data` is a
`Fields` type (`map[string]interface{}`) with all your fields as well as the
default ones (see Entries section above):
```go
type MyJSONFormatter struct {
}
log.SetFormatter(new(MyJSONFormatter))
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
// Note this doesn't include Time, Level and Message which are available on
// the Entry. Consult `godoc` on information about those fields or read the
// source of the official loggers.
serialized, err := json.Marshal(entry.Data)
if err != nil {
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
}
return append(serialized, '\n'), nil
}
```
#### Logger as an `io.Writer`
Logrus can be transormed into an `io.Writer`. That writer is the end of an `io.Pipe` and it is your responsibility to close it.
```go
w := logger.Writer()
defer w.Close()
srv := http.Server{
// create a stdlib log.Logger that writes to
// logrus.Logger.
ErrorLog: log.New(w, "", 0),
}
```
Each line written to that writer will be printed the usual way, using formatters
and hooks. The level for those entries is `info`.
#### Rotation
Log rotation is not provided with Logrus. Log rotation should be done by an
external program (like `logrotate(8)`) that can compress and delete old log
entries. It should not be a feature of the application-level logger.
[godoc]: https://godoc.org/github.com/Sirupsen/logrus

Some files were not shown because too many files have changed in this diff Show more