// Package fastcopy provides equivalents of io.Copy and io.CopyN whose
// intermediate copy buffers are drawn from a sync.Pool rather than allocated
// on every call.
package fastcopy

import (
	"errors"
	"io"
	"sync"
	_ "unsafe" // NOTE(review): only needed by //go:linkname directives; remove if nothing in this file uses one.
)

// defaultBufferSize is the buffer length allocated when no positive size has
// been configured via Buffer.
const defaultBufferSize = 4096

var (
	// pool is the package-level CopyPool behind Buffer, Copy and CopyN.
	pool = CopyPool{size: defaultBufferSize}

	// errInvalidWrite is returned when a Writer reports an impossible byte
	// count. It carries the same message as the unexported io.errInvalidWrite
	// but is declared locally: pulling unexported standard-library symbols in
	// via //go:linkname is unsupported, and Go 1.23+ linkers restrict such
	// references.
	errInvalidWrite = errors.New("invalid write result")
)

// CopyPool provides a memory pool of byte buffers for io copies from readers
// to writers. The zero value is usable and allocates buffers of
// defaultBufferSize bytes. A CopyPool contains a sync.Pool and therefore must
// not be copied after first use.
type CopyPool struct {
	size int       // length of newly allocated buffers; < 1 selects the default
	pool sync.Pool // pooled *[]byte buffers
}

// Buffer calls Buffer on the package-level pool; see CopyPool.Buffer.
func Buffer(sz int) int {
	return pool.Buffer(sz)
}

// CopyN calls CopyN on the package-level pool; see CopyPool.CopyN.
func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
	return pool.CopyN(dst, src, n)
}

// Copy calls Copy on the package-level pool; see CopyPool.Copy.
func Copy(dst io.Writer, src io.Reader) (int64, error) {
	return pool.Copy(dst, src)
}

// Buffer sets the length of buffers the pool allocates when sz > 0, and
// returns the length now in effect; a sz <= 0 only queries it. Buffers that
// are already pooled keep whatever length they were allocated with.
//
// Buffer is NOT safe for concurrent use: call it BEFORE any other use of the
// CopyPool.
func (cp *CopyPool) Buffer(sz int) int {
	if sz > 0 {
		cp.size = sz
	} else if cp.size < 1 {
		return defaultBufferSize
	}
	return cp.size
}

// CopyN copies n bytes (or until an error) from src to dst, like io.CopyN,
// but using a pooled buffer. It returns the number of bytes copied and the
// first error encountered. For n >= 0 the count equals n if and only if the
// error is nil; if src is exhausted first, the error is io.EOF.
func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
	written, err := cp.Copy(dst, io.LimitReader(src, n))
	if written == n {
		return n, nil
	}
	if written < n && err == nil {
		// src stopped early; must have been EOF.
		err = io.EOF
	}
	return written, err
}

// Copy copies from src to dst until EOF or an error, like io.Copy, but the
// intermediate buffer comes from the pool. It returns the number of bytes
// written and the first error encountered; reaching EOF is not an error.
//
// As with io.Copy, if src implements io.WriterTo or dst implements
// io.ReaderFrom, that method performs the copy and no pooled buffer is used.
func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) {
	// Prefer using io.WriterTo to do the copy (avoids alloc + copy).
	if wt, ok := src.(io.WriterTo); ok {
		return wt.WriteTo(dst)
	}
	// Prefer using io.ReaderFrom to do the copy.
	if rf, ok := dst.(io.ReaderFrom); ok {
		return rf.ReadFrom(src)
	}

	// Hold on to the *[]byte from the pool and return that same pointer.
	// Putting the address of a fresh local (Put(&buf)) would force the local
	// to escape, costing a heap allocation on every call.
	bp, _ := cp.pool.Get().(*[]byte)
	if bp == nil {
		b := make([]byte, cp.Buffer(0))
		bp = &b
	}
	defer cp.pool.Put(bp)
	buf := *bp

	var written int64
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			// Write what was read before inspecting er, so a final read
			// that returns both data and an error still reaches dst.
			nw, ew := dst.Write(buf[:nr])
			if nw < 0 || nr < nw {
				// The writer reported an impossible count.
				if ew == nil {
					ew = errInvalidWrite
				}
				return written, ew
			}
			written += int64(nw)
			if ew != nil {
				return written, ew
			}
			if nr != nw {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			if er == io.EOF {
				return written, nil // EOF is the expected end of input
			}
			return written, er
		}
	}
}