diff --git a/chunk_streamer.go b/chunk_streamer.go index 575a7a5..8b98722 100644 --- a/chunk_streamer.go +++ b/chunk_streamer.go @@ -38,6 +38,9 @@ type ChunkStreamer struct { writerSched *chunkStreamerWriterSched + msgDec *message.Decoder + msgEnc *message.Encoder + selfState *StreamControlState peerState *StreamControlState @@ -71,6 +74,9 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig stopCh: make(chan struct{}), }, + msgDec: message.NewDecoder(nil), + msgEnc: message.NewEncoder(nil), + selfState: NewStreamControlState(config), peerState: NewStreamControlState(config), @@ -85,21 +91,31 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig return cs } -func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (int, uint32, error) { +func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) ( + chunkStreamID int, + timestamp uint32, + closer func(), + err error, +) { reader, err := cs.NewChunkReader() if err != nil { - return 0, 0, err + return 0, 0, nil, err } - defer reader.Close() + closer = func() { reader.Close() } + defer func() { + if err != nil { + closer() + } + }() - dec := message.NewDecoder(reader, message.TypeID(reader.messageTypeID)) - if err := dec.Decode(&cmsg.Message); err != nil { - return 0, 0, err + cs.msgDec.Reset(reader) + if err := cs.msgDec.Decode(message.TypeID(reader.messageTypeID), &cmsg.Message); err != nil { + return 0, 0, nil, err } cmsg.StreamID = reader.messageStreamID - return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), nil + return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), closer, nil } func (cs *ChunkStreamer) Write( @@ -114,8 +130,8 @@ func (cs *ChunkStreamer) Write( } //defer writer.Close() - enc := message.NewEncoder(writer) - if err := enc.Encode(cmsg.Message); err != nil { + cs.msgEnc.Reset(writer) + if err := cs.msgEnc.Encode(cmsg.Message); err != nil { return err } writer.timestamp = timestamp diff --git a/chunk_streamer_test.go b/chunk_streamer_test.go index a31aa17..f8a1a93 100644 --- a/chunk_streamer_test.go +++ b/chunk_streamer_test.go @@ -30,8 +30,9 @@ func TestStreamerSingleChunk(t *testing.T) { streamer := NewChunkStreamer(inbuf, outbuf, nil) chunkStreamID := 2 + videoContent := []byte("testtesttest") msg := &message.VideoMessage{ - Payload: []byte("testtesttest"), + Payload: bytes.NewReader(videoContent), } timestamp := uint32(72) @@ -59,14 +60,17 @@ func TestStreamerSingleChunk(t *testing.T) { assert.NotNil(t, r) defer r.Close() - dec := message.NewDecoder(r, message.TypeID(r.messageTypeID)) + dec := message.NewDecoder(r) var actualMsg message.Message - err = dec.Decode(&actualMsg) + err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg) assert.Nil(t, err) assert.Equal(t, uint64(timestamp), r.timestamp) // check message - assert.Equal(t, msg, actualMsg) + assert.Equal(t, actualMsg.TypeID(), msg.TypeID()) + actualMsgT := actualMsg.(*message.VideoMessage) + actualContent, _ := ioutil.ReadAll(actualMsgT.Payload) + assert.Equal(t, actualContent, videoContent) } func TestStreamerMultipleChunk(t *testing.T) { @@ -80,9 +84,10 @@ func TestStreamerMultipleChunk(t *testing.T) { streamer := NewChunkStreamer(inbuf, outbuf, nil) chunkStreamID := 2 + videoContent := []byte(strings.Repeat(payloadUnit, chunkSize)) msg := &message.VideoMessage{ // will be chunked (chunkSize * len(payloadUnit)) - Payload: []byte(strings.Repeat(payloadUnit, chunkSize)), + Payload: bytes.NewReader(videoContent), } timestamp := uint32(72) @@ -112,14 +117,17 @@ func TestStreamerMultipleChunk(t *testing.T) { assert.NotNil(t, r) defer r.Close() - dec := message.NewDecoder(r, message.TypeID(r.messageTypeID)) + dec := message.NewDecoder(r) var actualMsg message.Message - err = dec.Decode(&actualMsg) + err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg) assert.Nil(t, err) assert.Equal(t, uint64(timestamp), r.timestamp) // check message - assert.Equal(t, msg, actualMsg) + assert.Equal(t, actualMsg.TypeID(), msg.TypeID()) + actualMsgT := actualMsg.(*message.VideoMessage) + actualContent, _ := ioutil.ReadAll(actualMsgT.Payload) + assert.Equal(t, actualContent, videoContent) } func TestStreamerChunkExample1(t *testing.T) { @@ -332,7 +340,7 @@ func TestChunkStreamerDualWriter(t *testing.T) { err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{ StreamID: 0, Message: &message.VideoMessage{ - Payload: largePayload, + Payload: bytes.NewReader(largePayload), }, }) assert.Nil(t, err) @@ -366,7 +374,7 @@ func TestChunkStreamerDualWriterWithoutWaiting(t *testing.T) { err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{ StreamID: 0, Message: &message.VideoMessage{ - Payload: largePayload, + Payload: bytes.NewReader(largePayload), }, }) assert.Nil(t, err) diff --git a/conn.go b/conn.go index bac78ec..115f1e9 100644 --- a/conn.go +++ b/conn.go @@ -183,19 +183,21 @@ func (c *Conn) runHandleMessageLoop() error { return c.streamer.Err() default: - chunkStreamID, timestamp, err := c.streamer.Read(&cmsg) + chunkStreamID, timestamp, closer, err := c.streamer.Read(&cmsg) if err != nil { return err } - if err := c.dispatchStreamHandler(chunkStreamID, timestamp, &cmsg); err != nil { + if err := c.handleMessage(chunkStreamID, timestamp, closer, &cmsg); err != nil { return err // Shutdown the connection } } } } -func (c *Conn) dispatchStreamHandler(chunkStreamID int, timestamp uint32, cmsg *ChunkMessage) error { +func (c *Conn) handleMessage(chunkStreamID int, timestamp uint32, closer func(), cmsg *ChunkMessage) error { + defer closer() + stream, err := c.streams.At(cmsg.StreamID) if err != nil { if c.config.IgnoreMessagesOnNotExistStream { diff --git a/default_handler.go b/default_handler.go index d6b7b04..284c424 100644 --- a/default_handler.go +++ b/default_handler.go @@ -9,6 +9,7 @@ package rtmp import ( "github.com/yutopp/go-rtmp/message" + "io" ) var _ Handler = (*DefaultHandler)(nil) @@ -55,11 +56,11 @@ func (h *DefaultHandler) OnSetDataFrame(timestamp uint32, data *message.NetStrea return nil } -func (h *DefaultHandler) OnAudio(timestamp uint32, payload []byte) error { +func (h *DefaultHandler) OnAudio(timestamp uint32, payload io.Reader) error { return nil } -func (h *DefaultHandler) OnVideo(timestamp uint32, payload []byte) error { +func (h *DefaultHandler) OnVideo(timestamp uint32, payload io.Reader) error { return nil } diff --git a/example/server_demo/handler.go b/example/server_demo/handler.go index ab87cd0..2c78bf3 100644 --- a/example/server_demo/handler.go +++ b/example/server_demo/handler.go @@ -8,6 +8,7 @@ import ( flvtag "github.com/yutopp/go-flv/tag" "github.com/yutopp/go-rtmp" rtmpmsg "github.com/yutopp/go-rtmp/message" + "io" "log" "os" "path/filepath" @@ -80,11 +81,9 @@ func (h *Handler) OnSetDataFrame(timestamp uint32, data *rtmpmsg.NetStreamSetDat return nil } -func (h *Handler) OnAudio(timestamp uint32, payload []byte) error { - r := bytes.NewReader(payload) - +func (h *Handler) OnAudio(timestamp uint32, payload io.Reader) error { var audio flvtag.AudioData - if err := flvtag.DecodeAudioData(r, &audio); err != nil { + if err := flvtag.DecodeAudioData(payload, &audio); err != nil { return err } @@ -109,11 +108,9 @@ func (h *Handler) OnAudio(timestamp uint32, payload []byte) error { return nil } -func (h *Handler) OnVideo(timestamp uint32, payload []byte) error { - r := bytes.NewReader(payload) - +func (h *Handler) OnVideo(timestamp uint32, payload io.Reader) error { var video flvtag.VideoData - if err := flvtag.DecodeVideoData(r, &video); err != nil { + if err := flvtag.DecodeVideoData(payload, &video); err != nil { return err } diff --git a/handler.go b/handler.go index f4e9473..7a119a8 100644 --- a/handler.go +++ b/handler.go @@ -9,6 +9,7 @@ package rtmp import ( "github.com/yutopp/go-rtmp/message" + "io" ) type Handler interface { @@ -22,8 +23,8 @@ type Handler interface { OnFCPublish(timestamp uint32, cmd *message.NetStreamFCPublish) error OnFCUnpublish(timestamp uint32, cmd *message.NetStreamFCUnpublish) error OnSetDataFrame(timestamp uint32, data *message.NetStreamSetDataFrame) error - OnAudio(timestamp uint32, payload []byte) error - OnVideo(timestamp uint32, payload []byte) error + OnAudio(timestamp uint32, payload io.Reader) error + OnVideo(timestamp uint32, payload io.Reader) error OnUnknownMessage(timestamp uint32, msg message.Message) error OnUnknownCommandMessage(timestamp uint32, cmd *message.CommandMessage) error OnUnknownDataMessage(timestamp uint32, data *message.DataMessage) error diff --git a/message/common_test.go b/message/common_test.go index afcbc55..b07d2bf 100644 --- a/message/common_test.go +++ b/message/common_test.go @@ -7,6 +7,10 @@ package message +import ( + "bytes" +) + type testCase struct { Name string TypeID @@ -78,7 +82,7 @@ var testCases = []testCase{ Name: "AudioMessage", TypeID: TypeIDAudioMessage, Value: &AudioMessage{ - Payload: []byte("audio data"), + Payload: bytes.NewReader([]byte("audio data")), }, Binary: []byte("audio data"), }, @@ -86,7 +90,7 @@ var testCases = []testCase{ Name: "VideoMessage", TypeID: TypeIDVideoMessage, Value: &VideoMessage{ - Payload: []byte("video data"), + Payload: bytes.NewReader([]byte("video data")), }, Binary: []byte("video data"), }, @@ -97,7 +101,7 @@ var testCases = []testCase{ Value: &DataMessage{ Name: "test", Encoding: EncodingTypeAMF0, - Body: []byte("test"), + Body: bytes.NewReader([]byte("test")), }, Binary: []byte{ // Name: AMF0 / string marker @@ -120,7 +124,7 @@ var testCases = []testCase{ CommandName: "_result", TransactionID: 10, Encoding: EncodingTypeAMF0, - Body: []byte("test"), + Body: bytes.NewReader([]byte("test")), }, Binary: []byte{ // CommandName: AMF0 / string marker diff --git a/message/decoder.go b/message/decoder.go index bbfdf80..d627f0b 100644 --- a/message/decoder.go +++ b/message/decoder.go @@ -8,7 +8,6 @@ package message import ( - "bytes" "encoding/binary" "fmt" "github.com/pkg/errors" @@ -18,21 +17,21 @@ import ( ) type Decoder struct { - r io.Reader - typeID TypeID - - cacheBuffer bytes.Buffer + r io.Reader } -func NewDecoder(r io.Reader, typeID TypeID) *Decoder { +func NewDecoder(r io.Reader) *Decoder { return &Decoder{ - r: r, - typeID: typeID, + r: r, } } -func (dec *Decoder) Decode(msg *Message) error { - switch dec.typeID { +func (dec *Decoder) Reset(r io.Reader) { + dec.r = r +} + +func (dec *Decoder) Decode(typeID TypeID, msg *Message) error { + switch typeID { case TypeIDSetChunkSize: return dec.decodeSetChunkSize(msg) case TypeIDAbortMessage: @@ -64,7 +63,7 @@ func (dec *Decoder) Decode(msg *Message) error { case TypeIDAggregateMessage: return dec.decodeAggregateMessage(msg) default: - return fmt.Errorf("Unexpected message type(decode): ID = %d", dec.typeID) + return fmt.Errorf("Unexpected message type(decode): ID = %d", typeID) } } @@ -172,40 +171,16 @@ func (dec *Decoder) decodeSetPeerBandwidth(msg *Message) error { } func (dec *Decoder) decodeAudioMessage(msg *Message) error { - buf := &dec.cacheBuffer // TODO: Provide thread safety if needed - buf.Reset() - - _, err := io.Copy(buf, dec.r) - if err != nil { - return err - } - - // Copy ownership - bin := make([]byte, len(buf.Bytes())) - copy(bin, buf.Bytes()) - *msg = &AudioMessage{ - Payload: bin, + Payload: dec.r, // Share an ownership of the reader } return nil } func (dec *Decoder) decodeVideoMessage(msg *Message) error { - buf := &dec.cacheBuffer // TODO: Provide thread safety if needed - buf.Reset() - - _, err := io.Copy(buf, dec.r) - if err != nil { - return err - } - - // Copy ownership - bin := make([]byte, len(buf.Bytes())) - copy(bin, buf.Bytes()) - *msg = &VideoMessage{ - Payload: bin, + Payload: dec.r, // Share an ownership of the reader } return nil @@ -224,7 +199,7 @@ func (dec *Decoder) decodeCommandMessageAMF3(msg *Message) error { } func (dec *Decoder) decodeDataMessageAMF0(msg *Message) error { - if err := dec.decodeDataMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) { + if err := dec.decodeDataMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) { return amf0.NewDecoder(r), EncodingTypeAMF0 }); err != nil { return err @@ -238,7 +213,7 @@ func (dec *Decoder) decodeSharedObjectMessageAMF0(msg *Message) error { } func (dec *Decoder) decodeCommandMessageAMF0(msg *Message) error { - if err := dec.decodeCommandMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) { + if err := dec.decodeCommandMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) { return amf0.NewDecoder(r), EncodingTypeAMF0 }); err != nil { return err @@ -251,37 +226,25 @@ func (dec *Decoder) decodeAggregateMessage(msg *Message) error { return fmt.Errorf("Not implemented: AggregateMessage") } -func (dec *Decoder) decodeDataMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error { - d, encTy := f(r) +func (dec *Decoder) decodeDataMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error { + d, encTy := f(dec.r) var name string if err := d.Decode(&name); err != nil { return errors.Wrap(err, "Failed to decode name") } - buf := &dec.cacheBuffer // TODO: Provide thread safety if needed - buf.Reset() - - _, err := io.Copy(buf, dec.r) - if err != nil { - return errors.Wrap(err, "Failed to copy body payload") - } - - // Copy ownership - bin := make([]byte, len(buf.Bytes())) - copy(bin, buf.Bytes()) - *msg = &DataMessage{ Name: name, Encoding: encTy, - Body: bin, + Body: dec.r, // Share an ownership of the reader } return nil } -func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error { - d, encTy := f(r) +func (dec *Decoder) decodeCommandMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error { + d, encTy := f(dec.r) var name string if err := d.Decode(&name); err != nil { @@ -293,23 +256,11 @@ func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io. return errors.Wrap(err, "Failed to decode transactionID") } - buf := &dec.cacheBuffer // TODO: Provide thread safety if needed - buf.Reset() - - _, err := io.Copy(buf, dec.r) - if err != nil { - return errors.Wrap(err, "Failed to copy body payload") - } - - // Copy ownership - bin := make([]byte, len(buf.Bytes())) - copy(bin, buf.Bytes()) - *msg = &CommandMessage{ CommandName: name, TransactionID: transactionID, Encoding: encTy, - Body: bin, + Body: dec.r, // Share an ownership of the reader } return nil diff --git a/message/decoder_test.go b/message/decoder_test.go index 944646b..e7ea18a 100644 --- a/message/decoder_test.go +++ b/message/decoder_test.go @@ -21,30 +21,37 @@ func TestDecodeCommon(t *testing.T) { t.Parallel() buf := bytes.NewReader(tc.Binary) - dec := NewDecoder(buf, tc.TypeID) + dec := NewDecoder(buf) var msg Message - err := dec.Decode(&msg) + err := dec.Decode(tc.TypeID, &msg) assert.Nil(t, err) - assert.Equal(t, tc.Value, msg) + assertEqualMessage(t, tc.Value, msg) }) } } -func BenchmarkDecodeVideoMessage(b *testing.B) { - buf := new(bytes.Buffer) - for i := 0; i < 1024; i++ { - buf.WriteString("abcde") - } - if buf.Len() != 5*1024 { - b.Fatalf("Buffer becomes unexpected state: Len = %d", buf.Len()) +func BenchmarkDecode5KBVideoMessage(b *testing.B) { + sizes := []struct { + name string + len int + }{ + {"5KB", 5 * 1024}, + {"2MB", 2 * 1024 * 1024}, } + for _, size := range sizes { + b.Run(size.name, func(b *testing.B) { + buf := make([]byte, size.len) + r := bytes.NewReader(buf) + dec := NewDecoder(r) - dec := NewDecoder(buf, TypeIDVideoMessage) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Reset(buf) - b.ResetTimer() - for i := 0; i < b.N; i++ { - var msg Message - dec.Decode(&msg) + var msg Message + dec.Decode(TypeIDVideoMessage, &msg) + } + }) } } diff --git a/message/encoder.go b/message/encoder.go index 5aa66fb..f567273 100644 --- a/message/encoder.go +++ b/message/encoder.go @@ -8,7 +8,6 @@ package message import ( - "bytes" "encoding/binary" "fmt" "io" @@ -24,6 +23,10 @@ func NewEncoder(w io.Writer) *Encoder { } } +func (enc *Encoder) Reset(w io.Writer) { + enc.w = w +} + // Encode func (enc *Encoder) Encode(msg Message) error { switch msg := msg.(type) { @@ -124,7 +127,7 @@ func (enc *Encoder) encodeSetPeerBandwidth(m *SetPeerBandwidth) error { } func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error { - if _, err := enc.w.Write(m.Payload); err != nil { + if _, err := io.Copy(enc.w, m.Payload); err != nil { return err } @@ -132,7 +135,7 @@ func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error { } func (enc *Encoder) encodeVideoMessage(m *VideoMessage) error { - if _, err := enc.w.Write(m.Payload); err != nil { + if _, err := io.Copy(enc.w, m.Payload); err != nil { return err } @@ -150,7 +153,7 @@ func (enc *Encoder) encodeDataMessage(m *DataMessage) error { return err } - if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil { + if _, err := io.Copy(enc.w, m.Body); err != nil { return err } @@ -171,7 +174,7 @@ func (enc *Encoder) encodeCommandMessage(m *CommandMessage) error { return err } - if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil { + if _, err := io.Copy(enc.w, m.Body); err != nil { return err } diff --git a/message/message.go b/message/message.go index e3da751..662f83a 100644 --- a/message/message.go +++ b/message/message.go @@ -7,6 +7,10 @@ package message +import ( + "io" +) + type TypeID byte const ( @@ -97,7 +101,7 @@ func (m *SetPeerBandwidth) TypeID() TypeID { // AudioMessage(8) type AudioMessage struct { - Payload []byte + Payload io.Reader } func (m *AudioMessage) TypeID() TypeID { @@ -106,7 +110,7 @@ func (m *AudioMessage) TypeID() TypeID { // VideoMessage(9) type VideoMessage struct { - Payload []byte + Payload io.Reader } func (m *VideoMessage) TypeID() TypeID { @@ -117,7 +121,7 @@ func (m *VideoMessage) TypeID() TypeID { type DataMessage struct { Name string Encoding EncodingType - Body []byte + Body io.Reader } func (m *DataMessage) TypeID() TypeID { @@ -156,7 +160,7 @@ type CommandMessage struct { CommandName string TransactionID int64 Encoding EncodingType - Body []byte + Body io.Reader } func (m *CommandMessage) TypeID() TypeID { diff --git a/message/message_test.go b/message/message_test.go new file mode 100644 index 0000000..80ca9d0 --- /dev/null +++ b/message/message_test.go @@ -0,0 +1,72 @@ +// +// Copyright (c) 2018- yutopp (yutopp@gmail.com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt) +// + +package message + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "io" + "io/ioutil" + "testing" +) + +func assertEqualMessage(t *testing.T, expected, actual Message) { + assert.Equal(t, expected.TypeID(), actual.TypeID()) + + switch expected := expected.(type) { + case *AudioMessage: + actual, ok := actual.(*AudioMessage) + assert.True(t, ok) + + assertEqualPayload(t, expected.Payload, actual.Payload) + + case *VideoMessage: + actual, ok := actual.(*VideoMessage) + assert.True(t, ok) + + assertEqualPayload(t, expected.Payload, actual.Payload) + + case *DataMessage: + actual, ok := actual.(*DataMessage) + assert.True(t, ok) + + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.Encoding, actual.Encoding) + assertEqualPayload(t, expected.Body, actual.Body) + + case *CommandMessage: + actual, ok := actual.(*CommandMessage) + assert.True(t, ok) + + assert.Equal(t, expected.CommandName, actual.CommandName) + assert.Equal(t, expected.TransactionID, actual.TransactionID) + assert.Equal(t, expected.Encoding, actual.Encoding) + assertEqualPayload(t, expected.Body, actual.Body) + + default: + assert.Equal(t, expected, actual) + } +} + +func assertEqualPayload(t *testing.T, expected, actual io.Reader) { + expectedBin, err := ioutil.ReadAll(expected) + assert.Nil(t, err) + switch p := expected.(type) { + case *bytes.Reader: + defer p.Seek(0, io.SeekStart) // Restore test case states + default: + t.FailNow() + } + assert.NotZero(t, len(expectedBin)) + + actualBin, err := ioutil.ReadAll(actual) + assert.Nil(t, err) + assert.NotZero(t, len(actualBin)) + + assert.Equal(t, expectedBin, actualBin) +} diff --git a/stream.go b/stream.go index 54962db..014f261 100644 --- a/stream.go +++ b/stream.go @@ -79,11 +79,10 @@ func (s *Stream) Connect() (*message.NetConnectionConnectResult, error) { // TODO: check result select { case <-t.doneCh: - r := bytes.NewReader(t.body) - amfDec := message.NewAMFDecoder(r, t.encoding) + amfDec := message.NewAMFDecoder(t.body, t.encoding) var value message.AMFConvertible - if err := message.DecodeBodyConnectResult(r, amfDec, &value); err != nil { + if err := message.DecodeBodyConnectResult(t.body, amfDec, &value); err != nil { return nil, errors.Wrap(err, "Failed to decode result") } result := value.(*message.NetConnectionConnectResult) @@ -144,11 +143,10 @@ func (s *Stream) CreateStream() (*message.NetConnectionCreateStreamResult, error // TODO: check result select { case <-t.doneCh: - r := bytes.NewReader(t.body) - amfDec := message.NewAMFDecoder(r, t.encoding) + amfDec := message.NewAMFDecoder(t.body, t.encoding) var value message.AMFConvertible - if err := message.DecodeBodyCreateStreamResult(r, amfDec, &value); err != nil { + if err := message.DecodeBodyCreateStreamResult(t.body, amfDec, &value); err != nil { return nil, errors.Wrap(err, "Failed to decode result") } result := value.(*message.NetConnectionCreateStreamResult) @@ -222,7 +220,7 @@ func (s *Stream) writeCommandMessage( CommandName: commandName, TransactionID: transactionID, Encoding: s.encTy, - Body: buf.Bytes(), + Body: buf, }) } diff --git a/stream_handler.go b/stream_handler.go index 9e79013..6ab4e25 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -8,7 +8,6 @@ package rtmp import ( - "bytes" "github.com/pkg/errors" "github.com/sirupsen/logrus" "sync" @@ -152,10 +151,9 @@ func (h *streamHandler) handleData( ) error { bodyDecoder := message.DataBodyDecoderFor(dataMsg.Name) - r := bytes.NewReader(dataMsg.Body) - amfDec := message.NewAMFDecoder(r, dataMsg.Encoding) + amfDec := message.NewAMFDecoder(dataMsg.Body, dataMsg.Encoding) var value message.AMFConvertible - if err := bodyDecoder(r, amfDec, &value); err != nil { + if err := bodyDecoder(dataMsg.Body, amfDec, &value); err != nil { return err } @@ -179,10 +177,7 @@ func (h *streamHandler) handleCommand( } // Set result (NOTE: shoule use a mutex for t?) - t.commandName = cmdMsg.CommandName - t.encoding = cmdMsg.Encoding - t.body = cmdMsg.Body - close(t.doneCh) + t.Reply(cmdMsg.CommandName, cmdMsg.Encoding, cmdMsg.Body) // Remove transacaction because this transaction is resolved if err := h.stream.transactions.Delete(cmdMsg.TransactionID); err != nil { @@ -194,12 +189,11 @@ func (h *streamHandler) handleCommand( // TODO: Support onStatus } - r := bytes.NewReader(cmdMsg.Body) - amfDec := message.NewAMFDecoder(r, cmdMsg.Encoding) + amfDec := message.NewAMFDecoder(cmdMsg.Body, cmdMsg.Encoding) bodyDecoder := message.CmdBodyDecoderFor(cmdMsg.CommandName, cmdMsg.TransactionID) var value message.AMFConvertible - if err := bodyDecoder(r, amfDec, &value); err != nil { + if err := bodyDecoder(cmdMsg.Body, amfDec, &value); err != nil { return err } diff --git a/transactions.go b/transactions.go index f1f3d65..202682f 100644 --- a/transactions.go +++ b/transactions.go @@ -8,7 +8,9 @@ package rtmp import ( + "bytes" "github.com/pkg/errors" + "io" "sync" "github.com/yutopp/go-rtmp/message" @@ -17,10 +19,20 @@ import ( type transaction struct { commandName string encoding message.EncodingType - body []byte + body *bytes.Buffer + lastErr error doneCh chan struct{} } +func (t *transaction) Reply(commandName string, encoding message.EncodingType, body io.Reader) { + t.commandName = commandName + t.encoding = encoding + t.body = new(bytes.Buffer) + _, err := io.Copy(t.body, body) + t.lastErr = err + close(t.doneCh) +} + type transactions struct { transactions map[int64]*transaction m sync.RWMutex