go
/
misc
1
0
Fork 0
misc/crypto/stream.go

121 lines
2.7 KiB
Go

package crypto
import (
"bytes"
"crypto/cipher"
"errors"
"io"
)
// ErrShortCiphertext is an error that has text too short.
var ErrShortCiphertext = errors.New("input too short to be valid ciphertext")
/*
NewCryptoWriter returns block cipher writer.
from https://github.com/acasajus/dkeyczar/blob/master/streams.go
*/
func NewCryptoWriter(bm cipher.BlockMode, sink io.WriteCloser) io.WriteCloser {
return &cryptoWriter{
bm: bm,
sink: sink,
buffer: bytes.NewBuffer(nil),
}
}
type cryptoWriter struct {
bm cipher.BlockMode
buffer *bytes.Buffer
sink io.WriteCloser
count int
}
func (c *cryptoWriter) Write(data []byte) (int, error) {
if _, err := c.buffer.Write(data); err != nil {
return 0, err
}
bL := c.buffer.Len() - c.buffer.Len()%c.bm.BlockSize()
tmp := c.buffer.Next(bL)
c.bm.CryptBlocks(tmp, tmp)
wL := 0
for wL < len(tmp) {
n, err := c.sink.Write(tmp[wL:])
if err != nil {
return 0, err
}
wL += n
}
c.count += wL
return len(data), nil
}
func (c *cryptoWriter) Close() error {
tmp := Pkcs5pad(c.buffer.Next(c.buffer.Len()), c.bm.BlockSize())
c.bm.CryptBlocks(tmp, tmp)
wL := 0
for wL < len(tmp) {
n, err := c.sink.Write(tmp[wL:])
if err != nil {
return err
}
wL += n
}
c.count += wL
return c.sink.Close()
}
func NewCryptoReader(bm cipher.BlockMode, source io.ReadCloser) io.ReadCloser {
return &cryptoReader{
bm: bm,
source: source,
outBuf: bytes.NewBuffer(nil),
inBuf: bytes.NewBuffer(nil),
eof: false,
}
}
type cryptoReader struct {
bm cipher.BlockMode
outBuf *bytes.Buffer
inBuf *bytes.Buffer
source io.ReadCloser
eof bool
}
func (cr *cryptoReader) Read(data []byte) (int, error) {
missing := len(data) - cr.outBuf.Len()
for !cr.eof && missing > 0 {
toRead := missing + cr.bm.BlockSize() + 1 //Always go beyond the required data to be able to unpad when eof'ed
if off := toRead % cr.bm.BlockSize(); off > 0 {
toRead += cr.bm.BlockSize() - off //Make sure we read in multiples of blocksize
}
cr.inBuf.Grow(toRead)
n, err := io.CopyN(cr.inBuf, cr.source, int64(toRead))
if err == io.EOF {
cr.eof = true
} else if err != nil {
return 0, err
}
readBytes := int(n)
if readBytes%cr.bm.BlockSize() > 0 && cr.eof {
return 0, ErrShortCiphertext
}
bytesToDec := readBytes - readBytes%cr.bm.BlockSize()
tmpdata := cr.inBuf.Next(bytesToDec)
cr.bm.CryptBlocks(tmpdata, tmpdata)
if _, err := cr.outBuf.Write(tmpdata); err != nil {
return 0, err
}
if cr.eof {
pad := cr.outBuf.Bytes()[cr.outBuf.Len()-1]
cr.outBuf.Truncate(cr.outBuf.Len() - int(pad))
}
missing = len(data) - cr.outBuf.Len()
}
return cr.outBuf.Read(data)
}
func (cr *cryptoReader) Close() error {
cr.eof = true
return cr.source.Close()
}