mirror of
https://github.com/encounter/go-rtmp.git
synced 2026-03-30 11:12:49 -07:00
Merge pull request #25 from yutopp/feature/fix_bench_and_improve_performace
Use io.Reader instead of []byte for all cases
This commit is contained in:
+25
-9
@@ -38,6 +38,9 @@ type ChunkStreamer struct {
|
||||
|
||||
writerSched *chunkStreamerWriterSched
|
||||
|
||||
msgDec *message.Decoder
|
||||
msgEnc *message.Encoder
|
||||
|
||||
selfState *StreamControlState
|
||||
peerState *StreamControlState
|
||||
|
||||
@@ -71,6 +74,9 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig
|
||||
stopCh: make(chan struct{}),
|
||||
},
|
||||
|
||||
msgDec: message.NewDecoder(nil),
|
||||
msgEnc: message.NewEncoder(nil),
|
||||
|
||||
selfState: NewStreamControlState(config),
|
||||
peerState: NewStreamControlState(config),
|
||||
|
||||
@@ -85,21 +91,31 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig
|
||||
return cs
|
||||
}
|
||||
|
||||
func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (int, uint32, error) {
|
||||
func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (
|
||||
chunkStreamID int,
|
||||
timestamp uint32,
|
||||
closer func(),
|
||||
err error,
|
||||
) {
|
||||
reader, err := cs.NewChunkReader()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
closer = func() { reader.Close() }
|
||||
defer func() {
|
||||
if err != nil {
|
||||
closer()
|
||||
}
|
||||
}()
|
||||
|
||||
dec := message.NewDecoder(reader, message.TypeID(reader.messageTypeID))
|
||||
if err := dec.Decode(&cmsg.Message); err != nil {
|
||||
return 0, 0, err
|
||||
cs.msgDec.Reset(reader)
|
||||
if err := cs.msgDec.Decode(message.TypeID(reader.messageTypeID), &cmsg.Message); err != nil {
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
|
||||
cmsg.StreamID = reader.messageStreamID
|
||||
|
||||
return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), nil
|
||||
return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), closer, nil
|
||||
}
|
||||
|
||||
func (cs *ChunkStreamer) Write(
|
||||
@@ -114,8 +130,8 @@ func (cs *ChunkStreamer) Write(
|
||||
}
|
||||
//defer writer.Close()
|
||||
|
||||
enc := message.NewEncoder(writer)
|
||||
if err := enc.Encode(cmsg.Message); err != nil {
|
||||
cs.msgEnc.Reset(writer)
|
||||
if err := cs.msgEnc.Encode(cmsg.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
writer.timestamp = timestamp
|
||||
|
||||
+18
-10
@@ -30,8 +30,9 @@ func TestStreamerSingleChunk(t *testing.T) {
|
||||
streamer := NewChunkStreamer(inbuf, outbuf, nil)
|
||||
|
||||
chunkStreamID := 2
|
||||
videoContent := []byte("testtesttest")
|
||||
msg := &message.VideoMessage{
|
||||
Payload: []byte("testtesttest"),
|
||||
Payload: bytes.NewReader(videoContent),
|
||||
}
|
||||
timestamp := uint32(72)
|
||||
|
||||
@@ -59,14 +60,17 @@ func TestStreamerSingleChunk(t *testing.T) {
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
|
||||
dec := message.NewDecoder(r, message.TypeID(r.messageTypeID))
|
||||
dec := message.NewDecoder(r)
|
||||
var actualMsg message.Message
|
||||
err = dec.Decode(&actualMsg)
|
||||
err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, uint64(timestamp), r.timestamp)
|
||||
|
||||
// check message
|
||||
assert.Equal(t, msg, actualMsg)
|
||||
assert.Equal(t, actualMsg.TypeID(), msg.TypeID())
|
||||
actualMsgT := actualMsg.(*message.VideoMessage)
|
||||
actualContent, _ := ioutil.ReadAll(actualMsgT.Payload)
|
||||
assert.Equal(t, actualContent, videoContent)
|
||||
}
|
||||
|
||||
func TestStreamerMultipleChunk(t *testing.T) {
|
||||
@@ -80,9 +84,10 @@ func TestStreamerMultipleChunk(t *testing.T) {
|
||||
streamer := NewChunkStreamer(inbuf, outbuf, nil)
|
||||
|
||||
chunkStreamID := 2
|
||||
videoContent := []byte(strings.Repeat(payloadUnit, chunkSize))
|
||||
msg := &message.VideoMessage{
|
||||
// will be chunked (chunkSize * len(payloadUnit))
|
||||
Payload: []byte(strings.Repeat(payloadUnit, chunkSize)),
|
||||
Payload: bytes.NewReader(videoContent),
|
||||
}
|
||||
timestamp := uint32(72)
|
||||
|
||||
@@ -112,14 +117,17 @@ func TestStreamerMultipleChunk(t *testing.T) {
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
|
||||
dec := message.NewDecoder(r, message.TypeID(r.messageTypeID))
|
||||
dec := message.NewDecoder(r)
|
||||
var actualMsg message.Message
|
||||
err = dec.Decode(&actualMsg)
|
||||
err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, uint64(timestamp), r.timestamp)
|
||||
|
||||
// check message
|
||||
assert.Equal(t, msg, actualMsg)
|
||||
assert.Equal(t, actualMsg.TypeID(), msg.TypeID())
|
||||
actualMsgT := actualMsg.(*message.VideoMessage)
|
||||
actualContent, _ := ioutil.ReadAll(actualMsgT.Payload)
|
||||
assert.Equal(t, actualContent, videoContent)
|
||||
}
|
||||
|
||||
func TestStreamerChunkExample1(t *testing.T) {
|
||||
@@ -332,7 +340,7 @@ func TestChunkStreamerDualWriter(t *testing.T) {
|
||||
err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{
|
||||
StreamID: 0,
|
||||
Message: &message.VideoMessage{
|
||||
Payload: largePayload,
|
||||
Payload: bytes.NewReader(largePayload),
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
@@ -366,7 +374,7 @@ func TestChunkStreamerDualWriterWithoutWaiting(t *testing.T) {
|
||||
err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{
|
||||
StreamID: 0,
|
||||
Message: &message.VideoMessage{
|
||||
Payload: largePayload,
|
||||
Payload: bytes.NewReader(largePayload),
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
@@ -183,19 +183,21 @@ func (c *Conn) runHandleMessageLoop() error {
|
||||
return c.streamer.Err()
|
||||
|
||||
default:
|
||||
chunkStreamID, timestamp, err := c.streamer.Read(&cmsg)
|
||||
chunkStreamID, timestamp, closer, err := c.streamer.Read(&cmsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.dispatchStreamHandler(chunkStreamID, timestamp, &cmsg); err != nil {
|
||||
if err := c.handleMessage(chunkStreamID, timestamp, closer, &cmsg); err != nil {
|
||||
return err // Shutdown the connection
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) dispatchStreamHandler(chunkStreamID int, timestamp uint32, cmsg *ChunkMessage) error {
|
||||
func (c *Conn) handleMessage(chunkStreamID int, timestamp uint32, closer func(), cmsg *ChunkMessage) error {
|
||||
defer closer()
|
||||
|
||||
stream, err := c.streams.At(cmsg.StreamID)
|
||||
if err != nil {
|
||||
if c.config.IgnoreMessagesOnNotExistStream {
|
||||
|
||||
+3
-2
@@ -9,6 +9,7 @@ package rtmp
|
||||
|
||||
import (
|
||||
"github.com/yutopp/go-rtmp/message"
|
||||
"io"
|
||||
)
|
||||
|
||||
var _ Handler = (*DefaultHandler)(nil)
|
||||
@@ -55,11 +56,11 @@ func (h *DefaultHandler) OnSetDataFrame(timestamp uint32, data *message.NetStrea
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *DefaultHandler) OnAudio(timestamp uint32, payload []byte) error {
|
||||
func (h *DefaultHandler) OnAudio(timestamp uint32, payload io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *DefaultHandler) OnVideo(timestamp uint32, payload []byte) error {
|
||||
func (h *DefaultHandler) OnVideo(timestamp uint32, payload io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
flvtag "github.com/yutopp/go-flv/tag"
|
||||
"github.com/yutopp/go-rtmp"
|
||||
rtmpmsg "github.com/yutopp/go-rtmp/message"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -80,11 +81,9 @@ func (h *Handler) OnSetDataFrame(timestamp uint32, data *rtmpmsg.NetStreamSetDat
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) OnAudio(timestamp uint32, payload []byte) error {
|
||||
r := bytes.NewReader(payload)
|
||||
|
||||
func (h *Handler) OnAudio(timestamp uint32, payload io.Reader) error {
|
||||
var audio flvtag.AudioData
|
||||
if err := flvtag.DecodeAudioData(r, &audio); err != nil {
|
||||
if err := flvtag.DecodeAudioData(payload, &audio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -109,11 +108,9 @@ func (h *Handler) OnAudio(timestamp uint32, payload []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) OnVideo(timestamp uint32, payload []byte) error {
|
||||
r := bytes.NewReader(payload)
|
||||
|
||||
func (h *Handler) OnVideo(timestamp uint32, payload io.Reader) error {
|
||||
var video flvtag.VideoData
|
||||
if err := flvtag.DecodeVideoData(r, &video); err != nil {
|
||||
if err := flvtag.DecodeVideoData(payload, &video); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+3
-2
@@ -9,6 +9,7 @@ package rtmp
|
||||
|
||||
import (
|
||||
"github.com/yutopp/go-rtmp/message"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
@@ -22,8 +23,8 @@ type Handler interface {
|
||||
OnFCPublish(timestamp uint32, cmd *message.NetStreamFCPublish) error
|
||||
OnFCUnpublish(timestamp uint32, cmd *message.NetStreamFCUnpublish) error
|
||||
OnSetDataFrame(timestamp uint32, data *message.NetStreamSetDataFrame) error
|
||||
OnAudio(timestamp uint32, payload []byte) error
|
||||
OnVideo(timestamp uint32, payload []byte) error
|
||||
OnAudio(timestamp uint32, payload io.Reader) error
|
||||
OnVideo(timestamp uint32, payload io.Reader) error
|
||||
OnUnknownMessage(timestamp uint32, msg message.Message) error
|
||||
OnUnknownCommandMessage(timestamp uint32, cmd *message.CommandMessage) error
|
||||
OnUnknownDataMessage(timestamp uint32, data *message.DataMessage) error
|
||||
|
||||
@@ -7,6 +7,10 @@
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
Name string
|
||||
TypeID
|
||||
@@ -78,7 +82,7 @@ var testCases = []testCase{
|
||||
Name: "AudioMessage",
|
||||
TypeID: TypeIDAudioMessage,
|
||||
Value: &AudioMessage{
|
||||
Payload: []byte("audio data"),
|
||||
Payload: bytes.NewReader([]byte("audio data")),
|
||||
},
|
||||
Binary: []byte("audio data"),
|
||||
},
|
||||
@@ -86,7 +90,7 @@ var testCases = []testCase{
|
||||
Name: "VideoMessage",
|
||||
TypeID: TypeIDVideoMessage,
|
||||
Value: &VideoMessage{
|
||||
Payload: []byte("video data"),
|
||||
Payload: bytes.NewReader([]byte("video data")),
|
||||
},
|
||||
Binary: []byte("video data"),
|
||||
},
|
||||
@@ -97,7 +101,7 @@ var testCases = []testCase{
|
||||
Value: &DataMessage{
|
||||
Name: "test",
|
||||
Encoding: EncodingTypeAMF0,
|
||||
Body: []byte("test"),
|
||||
Body: bytes.NewReader([]byte("test")),
|
||||
},
|
||||
Binary: []byte{
|
||||
// Name: AMF0 / string marker
|
||||
@@ -120,7 +124,7 @@ var testCases = []testCase{
|
||||
CommandName: "_result",
|
||||
TransactionID: 10,
|
||||
Encoding: EncodingTypeAMF0,
|
||||
Body: []byte("test"),
|
||||
Body: bytes.NewReader([]byte("test")),
|
||||
},
|
||||
Binary: []byte{
|
||||
// CommandName: AMF0 / string marker
|
||||
|
||||
+20
-69
@@ -8,7 +8,6 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
@@ -18,21 +17,21 @@ import (
|
||||
)
|
||||
|
||||
type Decoder struct {
|
||||
r io.Reader
|
||||
typeID TypeID
|
||||
|
||||
cacheBuffer bytes.Buffer
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func NewDecoder(r io.Reader, typeID TypeID) *Decoder {
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{
|
||||
r: r,
|
||||
typeID: typeID,
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (dec *Decoder) Decode(msg *Message) error {
|
||||
switch dec.typeID {
|
||||
func (dec *Decoder) Reset(r io.Reader) {
|
||||
dec.r = r
|
||||
}
|
||||
|
||||
func (dec *Decoder) Decode(typeID TypeID, msg *Message) error {
|
||||
switch typeID {
|
||||
case TypeIDSetChunkSize:
|
||||
return dec.decodeSetChunkSize(msg)
|
||||
case TypeIDAbortMessage:
|
||||
@@ -64,7 +63,7 @@ func (dec *Decoder) Decode(msg *Message) error {
|
||||
case TypeIDAggregateMessage:
|
||||
return dec.decodeAggregateMessage(msg)
|
||||
default:
|
||||
return fmt.Errorf("Unexpected message type(decode): ID = %d", dec.typeID)
|
||||
return fmt.Errorf("Unexpected message type(decode): ID = %d", typeID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,40 +171,16 @@ func (dec *Decoder) decodeSetPeerBandwidth(msg *Message) error {
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeAudioMessage(msg *Message) error {
|
||||
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
|
||||
buf.Reset()
|
||||
|
||||
_, err := io.Copy(buf, dec.r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy ownership
|
||||
bin := make([]byte, len(buf.Bytes()))
|
||||
copy(bin, buf.Bytes())
|
||||
|
||||
*msg = &AudioMessage{
|
||||
Payload: bin,
|
||||
Payload: dec.r, // Share an ownership of the reader
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeVideoMessage(msg *Message) error {
|
||||
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
|
||||
buf.Reset()
|
||||
|
||||
_, err := io.Copy(buf, dec.r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy ownership
|
||||
bin := make([]byte, len(buf.Bytes()))
|
||||
copy(bin, buf.Bytes())
|
||||
|
||||
*msg = &VideoMessage{
|
||||
Payload: bin,
|
||||
Payload: dec.r, // Share an ownership of the reader
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -224,7 +199,7 @@ func (dec *Decoder) decodeCommandMessageAMF3(msg *Message) error {
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeDataMessageAMF0(msg *Message) error {
|
||||
if err := dec.decodeDataMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) {
|
||||
if err := dec.decodeDataMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) {
|
||||
return amf0.NewDecoder(r), EncodingTypeAMF0
|
||||
}); err != nil {
|
||||
return err
|
||||
@@ -238,7 +213,7 @@ func (dec *Decoder) decodeSharedObjectMessageAMF0(msg *Message) error {
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeCommandMessageAMF0(msg *Message) error {
|
||||
if err := dec.decodeCommandMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) {
|
||||
if err := dec.decodeCommandMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) {
|
||||
return amf0.NewDecoder(r), EncodingTypeAMF0
|
||||
}); err != nil {
|
||||
return err
|
||||
@@ -251,37 +226,25 @@ func (dec *Decoder) decodeAggregateMessage(msg *Message) error {
|
||||
return fmt.Errorf("Not implemented: AggregateMessage")
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeDataMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
|
||||
d, encTy := f(r)
|
||||
func (dec *Decoder) decodeDataMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
|
||||
d, encTy := f(dec.r)
|
||||
|
||||
var name string
|
||||
if err := d.Decode(&name); err != nil {
|
||||
return errors.Wrap(err, "Failed to decode name")
|
||||
}
|
||||
|
||||
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
|
||||
buf.Reset()
|
||||
|
||||
_, err := io.Copy(buf, dec.r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to copy body payload")
|
||||
}
|
||||
|
||||
// Copy ownership
|
||||
bin := make([]byte, len(buf.Bytes()))
|
||||
copy(bin, buf.Bytes())
|
||||
|
||||
*msg = &DataMessage{
|
||||
Name: name,
|
||||
Encoding: encTy,
|
||||
Body: bin,
|
||||
Body: dec.r, // Share an ownership of the reader
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
|
||||
d, encTy := f(r)
|
||||
func (dec *Decoder) decodeCommandMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
|
||||
d, encTy := f(dec.r)
|
||||
|
||||
var name string
|
||||
if err := d.Decode(&name); err != nil {
|
||||
@@ -293,23 +256,11 @@ func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io.
|
||||
return errors.Wrap(err, "Failed to decode transactionID")
|
||||
}
|
||||
|
||||
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
|
||||
buf.Reset()
|
||||
|
||||
_, err := io.Copy(buf, dec.r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to copy body payload")
|
||||
}
|
||||
|
||||
// Copy ownership
|
||||
bin := make([]byte, len(buf.Bytes()))
|
||||
copy(bin, buf.Bytes())
|
||||
|
||||
*msg = &CommandMessage{
|
||||
CommandName: name,
|
||||
TransactionID: transactionID,
|
||||
Encoding: encTy,
|
||||
Body: bin,
|
||||
Body: dec.r, // Share an ownership of the reader
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+22
-15
@@ -21,30 +21,37 @@ func TestDecodeCommon(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := bytes.NewReader(tc.Binary)
|
||||
dec := NewDecoder(buf, tc.TypeID)
|
||||
dec := NewDecoder(buf)
|
||||
|
||||
var msg Message
|
||||
err := dec.Decode(&msg)
|
||||
err := dec.Decode(tc.TypeID, &msg)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tc.Value, msg)
|
||||
assertEqualMessage(t, tc.Value, msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeVideoMessage(b *testing.B) {
|
||||
buf := new(bytes.Buffer)
|
||||
for i := 0; i < 1024; i++ {
|
||||
buf.WriteString("abcde")
|
||||
}
|
||||
if buf.Len() != 5*1024 {
|
||||
b.Fatalf("Buffer becomes unexpected state: Len = %d", buf.Len())
|
||||
func BenchmarkDecode5KBVideoMessage(b *testing.B) {
|
||||
sizes := []struct {
|
||||
name string
|
||||
len int
|
||||
}{
|
||||
{"5KB", 5 * 1024},
|
||||
{"2MB", 2 * 1024 * 1024},
|
||||
}
|
||||
for _, size := range sizes {
|
||||
b.Run(size.name, func(b *testing.B) {
|
||||
buf := make([]byte, size.len)
|
||||
r := bytes.NewReader(buf)
|
||||
dec := NewDecoder(r)
|
||||
|
||||
dec := NewDecoder(buf, TypeIDVideoMessage)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset(buf)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var msg Message
|
||||
dec.Decode(&msg)
|
||||
var msg Message
|
||||
dec.Decode(TypeIDVideoMessage, &msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+8
-5
@@ -8,7 +8,6 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -24,6 +23,10 @@ func NewEncoder(w io.Writer) *Encoder {
|
||||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) Reset(w io.Writer) {
|
||||
enc.w = w
|
||||
}
|
||||
|
||||
// Encode
|
||||
func (enc *Encoder) Encode(msg Message) error {
|
||||
switch msg := msg.(type) {
|
||||
@@ -124,7 +127,7 @@ func (enc *Encoder) encodeSetPeerBandwidth(m *SetPeerBandwidth) error {
|
||||
}
|
||||
|
||||
func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error {
|
||||
if _, err := enc.w.Write(m.Payload); err != nil {
|
||||
if _, err := io.Copy(enc.w, m.Payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -132,7 +135,7 @@ func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error {
|
||||
}
|
||||
|
||||
func (enc *Encoder) encodeVideoMessage(m *VideoMessage) error {
|
||||
if _, err := enc.w.Write(m.Payload); err != nil {
|
||||
if _, err := io.Copy(enc.w, m.Payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -150,7 +153,7 @@ func (enc *Encoder) encodeDataMessage(m *DataMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil {
|
||||
if _, err := io.Copy(enc.w, m.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -171,7 +174,7 @@ func (enc *Encoder) encodeCommandMessage(m *CommandMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil {
|
||||
if _, err := io.Copy(enc.w, m.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+8
-4
@@ -7,6 +7,10 @@
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type TypeID byte
|
||||
|
||||
const (
|
||||
@@ -97,7 +101,7 @@ func (m *SetPeerBandwidth) TypeID() TypeID {
|
||||
|
||||
// AudioMessage(8)
|
||||
type AudioMessage struct {
|
||||
Payload []byte
|
||||
Payload io.Reader
|
||||
}
|
||||
|
||||
func (m *AudioMessage) TypeID() TypeID {
|
||||
@@ -106,7 +110,7 @@ func (m *AudioMessage) TypeID() TypeID {
|
||||
|
||||
// VideoMessage(9)
|
||||
type VideoMessage struct {
|
||||
Payload []byte
|
||||
Payload io.Reader
|
||||
}
|
||||
|
||||
func (m *VideoMessage) TypeID() TypeID {
|
||||
@@ -117,7 +121,7 @@ func (m *VideoMessage) TypeID() TypeID {
|
||||
type DataMessage struct {
|
||||
Name string
|
||||
Encoding EncodingType
|
||||
Body []byte
|
||||
Body io.Reader
|
||||
}
|
||||
|
||||
func (m *DataMessage) TypeID() TypeID {
|
||||
@@ -156,7 +160,7 @@ type CommandMessage struct {
|
||||
CommandName string
|
||||
TransactionID int64
|
||||
Encoding EncodingType
|
||||
Body []byte
|
||||
Body io.Reader
|
||||
}
|
||||
|
||||
func (m *CommandMessage) TypeID() TypeID {
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
//
|
||||
// Copyright (c) 2018- yutopp (yutopp@gmail.com)
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertEqualMessage(t *testing.T, expected, actual Message) {
|
||||
assert.Equal(t, expected.TypeID(), actual.TypeID())
|
||||
|
||||
switch expected := expected.(type) {
|
||||
case *AudioMessage:
|
||||
actual, ok := actual.(*AudioMessage)
|
||||
assert.True(t, ok)
|
||||
|
||||
assertEqualPayload(t, expected.Payload, actual.Payload)
|
||||
|
||||
case *VideoMessage:
|
||||
actual, ok := actual.(*VideoMessage)
|
||||
assert.True(t, ok)
|
||||
|
||||
assertEqualPayload(t, expected.Payload, actual.Payload)
|
||||
|
||||
case *DataMessage:
|
||||
actual, ok := actual.(*DataMessage)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, expected.Name, actual.Name)
|
||||
assert.Equal(t, expected.Encoding, actual.Encoding)
|
||||
assertEqualPayload(t, expected.Body, actual.Body)
|
||||
|
||||
case *CommandMessage:
|
||||
actual, ok := actual.(*CommandMessage)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, expected.CommandName, actual.CommandName)
|
||||
assert.Equal(t, expected.TransactionID, actual.TransactionID)
|
||||
assert.Equal(t, expected.Encoding, actual.Encoding)
|
||||
assertEqualPayload(t, expected.Body, actual.Body)
|
||||
|
||||
default:
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqualPayload(t *testing.T, expected, actual io.Reader) {
|
||||
expectedBin, err := ioutil.ReadAll(expected)
|
||||
assert.Nil(t, err)
|
||||
switch p := expected.(type) {
|
||||
case *bytes.Reader:
|
||||
defer p.Seek(0, io.SeekStart) // Restore test case states
|
||||
default:
|
||||
t.FailNow()
|
||||
}
|
||||
assert.NotZero(t, len(expectedBin))
|
||||
|
||||
actualBin, err := ioutil.ReadAll(actual)
|
||||
assert.Nil(t, err)
|
||||
assert.NotZero(t, len(actualBin))
|
||||
|
||||
assert.Equal(t, expectedBin, actualBin)
|
||||
}
|
||||
@@ -79,11 +79,10 @@ func (s *Stream) Connect() (*message.NetConnectionConnectResult, error) {
|
||||
// TODO: check result
|
||||
select {
|
||||
case <-t.doneCh:
|
||||
r := bytes.NewReader(t.body)
|
||||
amfDec := message.NewAMFDecoder(r, t.encoding)
|
||||
amfDec := message.NewAMFDecoder(t.body, t.encoding)
|
||||
|
||||
var value message.AMFConvertible
|
||||
if err := message.DecodeBodyConnectResult(r, amfDec, &value); err != nil {
|
||||
if err := message.DecodeBodyConnectResult(t.body, amfDec, &value); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode result")
|
||||
}
|
||||
result := value.(*message.NetConnectionConnectResult)
|
||||
@@ -144,11 +143,10 @@ func (s *Stream) CreateStream() (*message.NetConnectionCreateStreamResult, error
|
||||
// TODO: check result
|
||||
select {
|
||||
case <-t.doneCh:
|
||||
r := bytes.NewReader(t.body)
|
||||
amfDec := message.NewAMFDecoder(r, t.encoding)
|
||||
amfDec := message.NewAMFDecoder(t.body, t.encoding)
|
||||
|
||||
var value message.AMFConvertible
|
||||
if err := message.DecodeBodyCreateStreamResult(r, amfDec, &value); err != nil {
|
||||
if err := message.DecodeBodyCreateStreamResult(t.body, amfDec, &value); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode result")
|
||||
}
|
||||
result := value.(*message.NetConnectionCreateStreamResult)
|
||||
@@ -222,7 +220,7 @@ func (s *Stream) writeCommandMessage(
|
||||
CommandName: commandName,
|
||||
TransactionID: transactionID,
|
||||
Encoding: s.encTy,
|
||||
Body: buf.Bytes(),
|
||||
Body: buf,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+5
-11
@@ -8,7 +8,6 @@
|
||||
package rtmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
@@ -152,10 +151,9 @@ func (h *streamHandler) handleData(
|
||||
) error {
|
||||
bodyDecoder := message.DataBodyDecoderFor(dataMsg.Name)
|
||||
|
||||
r := bytes.NewReader(dataMsg.Body)
|
||||
amfDec := message.NewAMFDecoder(r, dataMsg.Encoding)
|
||||
amfDec := message.NewAMFDecoder(dataMsg.Body, dataMsg.Encoding)
|
||||
var value message.AMFConvertible
|
||||
if err := bodyDecoder(r, amfDec, &value); err != nil {
|
||||
if err := bodyDecoder(dataMsg.Body, amfDec, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -179,10 +177,7 @@ func (h *streamHandler) handleCommand(
|
||||
}
|
||||
|
||||
// Set result (NOTE: shoule use a mutex for t?)
|
||||
t.commandName = cmdMsg.CommandName
|
||||
t.encoding = cmdMsg.Encoding
|
||||
t.body = cmdMsg.Body
|
||||
close(t.doneCh)
|
||||
t.Reply(cmdMsg.CommandName, cmdMsg.Encoding, cmdMsg.Body)
|
||||
|
||||
// Remove transacaction because this transaction is resolved
|
||||
if err := h.stream.transactions.Delete(cmdMsg.TransactionID); err != nil {
|
||||
@@ -194,12 +189,11 @@ func (h *streamHandler) handleCommand(
|
||||
// TODO: Support onStatus
|
||||
}
|
||||
|
||||
r := bytes.NewReader(cmdMsg.Body)
|
||||
amfDec := message.NewAMFDecoder(r, cmdMsg.Encoding)
|
||||
amfDec := message.NewAMFDecoder(cmdMsg.Body, cmdMsg.Encoding)
|
||||
bodyDecoder := message.CmdBodyDecoderFor(cmdMsg.CommandName, cmdMsg.TransactionID)
|
||||
|
||||
var value message.AMFConvertible
|
||||
if err := bodyDecoder(r, amfDec, &value); err != nil {
|
||||
if err := bodyDecoder(cmdMsg.Body, amfDec, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+13
-1
@@ -8,7 +8,9 @@
|
||||
package rtmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/yutopp/go-rtmp/message"
|
||||
@@ -17,10 +19,20 @@ import (
|
||||
type transaction struct {
|
||||
commandName string
|
||||
encoding message.EncodingType
|
||||
body []byte
|
||||
body *bytes.Buffer
|
||||
lastErr error
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
func (t *transaction) Reply(commandName string, encoding message.EncodingType, body io.Reader) {
|
||||
t.commandName = commandName
|
||||
t.encoding = encoding
|
||||
t.body = new(bytes.Buffer)
|
||||
_, err := io.Copy(t.body, body)
|
||||
t.lastErr = err
|
||||
close(t.doneCh)
|
||||
}
|
||||
|
||||
type transactions struct {
|
||||
transactions map[int64]*transaction
|
||||
m sync.RWMutex
|
||||
|
||||
Reference in New Issue
Block a user