Merge pull request #25 from yutopp/feature/fix_bench_and_improve_performace

Use io.Reader instead of []byte for all cases
This commit is contained in:
yutopp
2018-10-05 22:27:49 +09:00
committed by GitHub
15 changed files with 220 additions and 150 deletions
+25 -9
View File
@@ -38,6 +38,9 @@ type ChunkStreamer struct {
writerSched *chunkStreamerWriterSched
msgDec *message.Decoder
msgEnc *message.Encoder
selfState *StreamControlState
peerState *StreamControlState
@@ -71,6 +74,9 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig
stopCh: make(chan struct{}),
},
msgDec: message.NewDecoder(nil),
msgEnc: message.NewEncoder(nil),
selfState: NewStreamControlState(config),
peerState: NewStreamControlState(config),
@@ -85,21 +91,31 @@ func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig
return cs
}
func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (int, uint32, error) {
func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (
chunkStreamID int,
timestamp uint32,
closer func(),
err error,
) {
reader, err := cs.NewChunkReader()
if err != nil {
return 0, 0, err
return 0, 0, nil, err
}
defer reader.Close()
closer = func() { reader.Close() }
defer func() {
if err != nil {
closer()
}
}()
dec := message.NewDecoder(reader, message.TypeID(reader.messageTypeID))
if err := dec.Decode(&cmsg.Message); err != nil {
return 0, 0, err
cs.msgDec.Reset(reader)
if err := cs.msgDec.Decode(message.TypeID(reader.messageTypeID), &cmsg.Message); err != nil {
return 0, 0, nil, err
}
cmsg.StreamID = reader.messageStreamID
return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), nil
return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), closer, nil
}
func (cs *ChunkStreamer) Write(
@@ -114,8 +130,8 @@ func (cs *ChunkStreamer) Write(
}
//defer writer.Close()
enc := message.NewEncoder(writer)
if err := enc.Encode(cmsg.Message); err != nil {
cs.msgEnc.Reset(writer)
if err := cs.msgEnc.Encode(cmsg.Message); err != nil {
return err
}
writer.timestamp = timestamp
+18 -10
View File
@@ -30,8 +30,9 @@ func TestStreamerSingleChunk(t *testing.T) {
streamer := NewChunkStreamer(inbuf, outbuf, nil)
chunkStreamID := 2
videoContent := []byte("testtesttest")
msg := &message.VideoMessage{
Payload: []byte("testtesttest"),
Payload: bytes.NewReader(videoContent),
}
timestamp := uint32(72)
@@ -59,14 +60,17 @@ func TestStreamerSingleChunk(t *testing.T) {
assert.NotNil(t, r)
defer r.Close()
dec := message.NewDecoder(r, message.TypeID(r.messageTypeID))
dec := message.NewDecoder(r)
var actualMsg message.Message
err = dec.Decode(&actualMsg)
err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg)
assert.Nil(t, err)
assert.Equal(t, uint64(timestamp), r.timestamp)
// check message
assert.Equal(t, msg, actualMsg)
assert.Equal(t, actualMsg.TypeID(), msg.TypeID())
actualMsgT := actualMsg.(*message.VideoMessage)
actualContent, _ := ioutil.ReadAll(actualMsgT.Payload)
assert.Equal(t, actualContent, videoContent)
}
func TestStreamerMultipleChunk(t *testing.T) {
@@ -80,9 +84,10 @@ func TestStreamerMultipleChunk(t *testing.T) {
streamer := NewChunkStreamer(inbuf, outbuf, nil)
chunkStreamID := 2
videoContent := []byte(strings.Repeat(payloadUnit, chunkSize))
msg := &message.VideoMessage{
// will be chunked (chunkSize * len(payloadUnit))
Payload: []byte(strings.Repeat(payloadUnit, chunkSize)),
Payload: bytes.NewReader(videoContent),
}
timestamp := uint32(72)
@@ -112,14 +117,17 @@ func TestStreamerMultipleChunk(t *testing.T) {
assert.NotNil(t, r)
defer r.Close()
dec := message.NewDecoder(r, message.TypeID(r.messageTypeID))
dec := message.NewDecoder(r)
var actualMsg message.Message
err = dec.Decode(&actualMsg)
err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg)
assert.Nil(t, err)
assert.Equal(t, uint64(timestamp), r.timestamp)
// check message
assert.Equal(t, msg, actualMsg)
assert.Equal(t, actualMsg.TypeID(), msg.TypeID())
actualMsgT := actualMsg.(*message.VideoMessage)
actualContent, _ := ioutil.ReadAll(actualMsgT.Payload)
assert.Equal(t, actualContent, videoContent)
}
func TestStreamerChunkExample1(t *testing.T) {
@@ -332,7 +340,7 @@ func TestChunkStreamerDualWriter(t *testing.T) {
err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{
StreamID: 0,
Message: &message.VideoMessage{
Payload: largePayload,
Payload: bytes.NewReader(largePayload),
},
})
assert.Nil(t, err)
@@ -366,7 +374,7 @@ func TestChunkStreamerDualWriterWithoutWaiting(t *testing.T) {
err := streamer.Write(context.Background(), chunkStreamID, timestamp, &ChunkMessage{
StreamID: 0,
Message: &message.VideoMessage{
Payload: largePayload,
Payload: bytes.NewReader(largePayload),
},
})
assert.Nil(t, err)
+5 -3
View File
@@ -183,19 +183,21 @@ func (c *Conn) runHandleMessageLoop() error {
return c.streamer.Err()
default:
chunkStreamID, timestamp, err := c.streamer.Read(&cmsg)
chunkStreamID, timestamp, closer, err := c.streamer.Read(&cmsg)
if err != nil {
return err
}
if err := c.dispatchStreamHandler(chunkStreamID, timestamp, &cmsg); err != nil {
if err := c.handleMessage(chunkStreamID, timestamp, closer, &cmsg); err != nil {
return err // Shutdown the connection
}
}
}
}
func (c *Conn) dispatchStreamHandler(chunkStreamID int, timestamp uint32, cmsg *ChunkMessage) error {
func (c *Conn) handleMessage(chunkStreamID int, timestamp uint32, closer func(), cmsg *ChunkMessage) error {
defer closer()
stream, err := c.streams.At(cmsg.StreamID)
if err != nil {
if c.config.IgnoreMessagesOnNotExistStream {
+3 -2
View File
@@ -9,6 +9,7 @@ package rtmp
import (
"github.com/yutopp/go-rtmp/message"
"io"
)
var _ Handler = (*DefaultHandler)(nil)
@@ -55,11 +56,11 @@ func (h *DefaultHandler) OnSetDataFrame(timestamp uint32, data *message.NetStrea
return nil
}
func (h *DefaultHandler) OnAudio(timestamp uint32, payload []byte) error {
func (h *DefaultHandler) OnAudio(timestamp uint32, payload io.Reader) error {
return nil
}
func (h *DefaultHandler) OnVideo(timestamp uint32, payload []byte) error {
func (h *DefaultHandler) OnVideo(timestamp uint32, payload io.Reader) error {
return nil
}
+5 -8
View File
@@ -8,6 +8,7 @@ import (
flvtag "github.com/yutopp/go-flv/tag"
"github.com/yutopp/go-rtmp"
rtmpmsg "github.com/yutopp/go-rtmp/message"
"io"
"log"
"os"
"path/filepath"
@@ -80,11 +81,9 @@ func (h *Handler) OnSetDataFrame(timestamp uint32, data *rtmpmsg.NetStreamSetDat
return nil
}
func (h *Handler) OnAudio(timestamp uint32, payload []byte) error {
r := bytes.NewReader(payload)
func (h *Handler) OnAudio(timestamp uint32, payload io.Reader) error {
var audio flvtag.AudioData
if err := flvtag.DecodeAudioData(r, &audio); err != nil {
if err := flvtag.DecodeAudioData(payload, &audio); err != nil {
return err
}
@@ -109,11 +108,9 @@ func (h *Handler) OnAudio(timestamp uint32, payload []byte) error {
return nil
}
func (h *Handler) OnVideo(timestamp uint32, payload []byte) error {
r := bytes.NewReader(payload)
func (h *Handler) OnVideo(timestamp uint32, payload io.Reader) error {
var video flvtag.VideoData
if err := flvtag.DecodeVideoData(r, &video); err != nil {
if err := flvtag.DecodeVideoData(payload, &video); err != nil {
return err
}
+3 -2
View File
@@ -9,6 +9,7 @@ package rtmp
import (
"github.com/yutopp/go-rtmp/message"
"io"
)
type Handler interface {
@@ -22,8 +23,8 @@ type Handler interface {
OnFCPublish(timestamp uint32, cmd *message.NetStreamFCPublish) error
OnFCUnpublish(timestamp uint32, cmd *message.NetStreamFCUnpublish) error
OnSetDataFrame(timestamp uint32, data *message.NetStreamSetDataFrame) error
OnAudio(timestamp uint32, payload []byte) error
OnVideo(timestamp uint32, payload []byte) error
OnAudio(timestamp uint32, payload io.Reader) error
OnVideo(timestamp uint32, payload io.Reader) error
OnUnknownMessage(timestamp uint32, msg message.Message) error
OnUnknownCommandMessage(timestamp uint32, cmd *message.CommandMessage) error
OnUnknownDataMessage(timestamp uint32, data *message.DataMessage) error
+8 -4
View File
@@ -7,6 +7,10 @@
package message
import (
"bytes"
)
type testCase struct {
Name string
TypeID
@@ -78,7 +82,7 @@ var testCases = []testCase{
Name: "AudioMessage",
TypeID: TypeIDAudioMessage,
Value: &AudioMessage{
Payload: []byte("audio data"),
Payload: bytes.NewReader([]byte("audio data")),
},
Binary: []byte("audio data"),
},
@@ -86,7 +90,7 @@ var testCases = []testCase{
Name: "VideoMessage",
TypeID: TypeIDVideoMessage,
Value: &VideoMessage{
Payload: []byte("video data"),
Payload: bytes.NewReader([]byte("video data")),
},
Binary: []byte("video data"),
},
@@ -97,7 +101,7 @@ var testCases = []testCase{
Value: &DataMessage{
Name: "test",
Encoding: EncodingTypeAMF0,
Body: []byte("test"),
Body: bytes.NewReader([]byte("test")),
},
Binary: []byte{
// Name: AMF0 / string marker
@@ -120,7 +124,7 @@ var testCases = []testCase{
CommandName: "_result",
TransactionID: 10,
Encoding: EncodingTypeAMF0,
Body: []byte("test"),
Body: bytes.NewReader([]byte("test")),
},
Binary: []byte{
// CommandName: AMF0 / string marker
+20 -69
View File
@@ -8,7 +8,6 @@
package message
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/pkg/errors"
@@ -18,21 +17,21 @@ import (
)
type Decoder struct {
r io.Reader
typeID TypeID
cacheBuffer bytes.Buffer
r io.Reader
}
func NewDecoder(r io.Reader, typeID TypeID) *Decoder {
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{
r: r,
typeID: typeID,
r: r,
}
}
func (dec *Decoder) Decode(msg *Message) error {
switch dec.typeID {
func (dec *Decoder) Reset(r io.Reader) {
dec.r = r
}
func (dec *Decoder) Decode(typeID TypeID, msg *Message) error {
switch typeID {
case TypeIDSetChunkSize:
return dec.decodeSetChunkSize(msg)
case TypeIDAbortMessage:
@@ -64,7 +63,7 @@ func (dec *Decoder) Decode(msg *Message) error {
case TypeIDAggregateMessage:
return dec.decodeAggregateMessage(msg)
default:
return fmt.Errorf("Unexpected message type(decode): ID = %d", dec.typeID)
return fmt.Errorf("Unexpected message type(decode): ID = %d", typeID)
}
}
@@ -172,40 +171,16 @@ func (dec *Decoder) decodeSetPeerBandwidth(msg *Message) error {
}
func (dec *Decoder) decodeAudioMessage(msg *Message) error {
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
buf.Reset()
_, err := io.Copy(buf, dec.r)
if err != nil {
return err
}
// Copy ownership
bin := make([]byte, len(buf.Bytes()))
copy(bin, buf.Bytes())
*msg = &AudioMessage{
Payload: bin,
Payload: dec.r, // Share an ownership of the reader
}
return nil
}
func (dec *Decoder) decodeVideoMessage(msg *Message) error {
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
buf.Reset()
_, err := io.Copy(buf, dec.r)
if err != nil {
return err
}
// Copy ownership
bin := make([]byte, len(buf.Bytes()))
copy(bin, buf.Bytes())
*msg = &VideoMessage{
Payload: bin,
Payload: dec.r, // Share an ownership of the reader
}
return nil
@@ -224,7 +199,7 @@ func (dec *Decoder) decodeCommandMessageAMF3(msg *Message) error {
}
func (dec *Decoder) decodeDataMessageAMF0(msg *Message) error {
if err := dec.decodeDataMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) {
if err := dec.decodeDataMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) {
return amf0.NewDecoder(r), EncodingTypeAMF0
}); err != nil {
return err
@@ -238,7 +213,7 @@ func (dec *Decoder) decodeSharedObjectMessageAMF0(msg *Message) error {
}
func (dec *Decoder) decodeCommandMessageAMF0(msg *Message) error {
if err := dec.decodeCommandMessage(dec.r, msg, func(r io.Reader) (AMFDecoder, EncodingType) {
if err := dec.decodeCommandMessage(msg, func(r io.Reader) (AMFDecoder, EncodingType) {
return amf0.NewDecoder(r), EncodingTypeAMF0
}); err != nil {
return err
@@ -251,37 +226,25 @@ func (dec *Decoder) decodeAggregateMessage(msg *Message) error {
return fmt.Errorf("Not implemented: AggregateMessage")
}
func (dec *Decoder) decodeDataMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
d, encTy := f(r)
func (dec *Decoder) decodeDataMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
d, encTy := f(dec.r)
var name string
if err := d.Decode(&name); err != nil {
return errors.Wrap(err, "Failed to decode name")
}
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
buf.Reset()
_, err := io.Copy(buf, dec.r)
if err != nil {
return errors.Wrap(err, "Failed to copy body payload")
}
// Copy ownership
bin := make([]byte, len(buf.Bytes()))
copy(bin, buf.Bytes())
*msg = &DataMessage{
Name: name,
Encoding: encTy,
Body: bin,
Body: dec.r, // Share an ownership of the reader
}
return nil
}
func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
d, encTy := f(r)
func (dec *Decoder) decodeCommandMessage(msg *Message, f func(r io.Reader) (AMFDecoder, EncodingType)) error {
d, encTy := f(dec.r)
var name string
if err := d.Decode(&name); err != nil {
@@ -293,23 +256,11 @@ func (dec *Decoder) decodeCommandMessage(r io.Reader, msg *Message, f func(r io.
return errors.Wrap(err, "Failed to decode transactionID")
}
buf := &dec.cacheBuffer // TODO: Provide thread safety if needed
buf.Reset()
_, err := io.Copy(buf, dec.r)
if err != nil {
return errors.Wrap(err, "Failed to copy body payload")
}
// Copy ownership
bin := make([]byte, len(buf.Bytes()))
copy(bin, buf.Bytes())
*msg = &CommandMessage{
CommandName: name,
TransactionID: transactionID,
Encoding: encTy,
Body: bin,
Body: dec.r, // Share an ownership of the reader
}
return nil
+22 -15
View File
@@ -21,30 +21,37 @@ func TestDecodeCommon(t *testing.T) {
t.Parallel()
buf := bytes.NewReader(tc.Binary)
dec := NewDecoder(buf, tc.TypeID)
dec := NewDecoder(buf)
var msg Message
err := dec.Decode(&msg)
err := dec.Decode(tc.TypeID, &msg)
assert.Nil(t, err)
assert.Equal(t, tc.Value, msg)
assertEqualMessage(t, tc.Value, msg)
})
}
}
func BenchmarkDecodeVideoMessage(b *testing.B) {
buf := new(bytes.Buffer)
for i := 0; i < 1024; i++ {
buf.WriteString("abcde")
}
if buf.Len() != 5*1024 {
b.Fatalf("Buffer becomes unexpected state: Len = %d", buf.Len())
func BenchmarkDecode5KBVideoMessage(b *testing.B) {
sizes := []struct {
name string
len int
}{
{"5KB", 5 * 1024},
{"2MB", 2 * 1024 * 1024},
}
for _, size := range sizes {
b.Run(size.name, func(b *testing.B) {
buf := make([]byte, size.len)
r := bytes.NewReader(buf)
dec := NewDecoder(r)
dec := NewDecoder(buf, TypeIDVideoMessage)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var msg Message
dec.Decode(&msg)
var msg Message
dec.Decode(TypeIDVideoMessage, &msg)
}
})
}
}
+8 -5
View File
@@ -8,7 +8,6 @@
package message
import (
"bytes"
"encoding/binary"
"fmt"
"io"
@@ -24,6 +23,10 @@ func NewEncoder(w io.Writer) *Encoder {
}
}
func (enc *Encoder) Reset(w io.Writer) {
enc.w = w
}
// Encode
func (enc *Encoder) Encode(msg Message) error {
switch msg := msg.(type) {
@@ -124,7 +127,7 @@ func (enc *Encoder) encodeSetPeerBandwidth(m *SetPeerBandwidth) error {
}
func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error {
if _, err := enc.w.Write(m.Payload); err != nil {
if _, err := io.Copy(enc.w, m.Payload); err != nil {
return err
}
@@ -132,7 +135,7 @@ func (enc *Encoder) encodeAudioMessage(m *AudioMessage) error {
}
func (enc *Encoder) encodeVideoMessage(m *VideoMessage) error {
if _, err := enc.w.Write(m.Payload); err != nil {
if _, err := io.Copy(enc.w, m.Payload); err != nil {
return err
}
@@ -150,7 +153,7 @@ func (enc *Encoder) encodeDataMessage(m *DataMessage) error {
return err
}
if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil {
if _, err := io.Copy(enc.w, m.Body); err != nil {
return err
}
@@ -171,7 +174,7 @@ func (enc *Encoder) encodeCommandMessage(m *CommandMessage) error {
return err
}
if _, err := io.Copy(enc.w, bytes.NewReader(m.Body)); err != nil {
if _, err := io.Copy(enc.w, m.Body); err != nil {
return err
}
+8 -4
View File
@@ -7,6 +7,10 @@
package message
import (
"io"
)
type TypeID byte
const (
@@ -97,7 +101,7 @@ func (m *SetPeerBandwidth) TypeID() TypeID {
// AudioMessage(8)
type AudioMessage struct {
Payload []byte
Payload io.Reader
}
func (m *AudioMessage) TypeID() TypeID {
@@ -106,7 +110,7 @@ func (m *AudioMessage) TypeID() TypeID {
// VideoMessage(9)
type VideoMessage struct {
Payload []byte
Payload io.Reader
}
func (m *VideoMessage) TypeID() TypeID {
@@ -117,7 +121,7 @@ func (m *VideoMessage) TypeID() TypeID {
type DataMessage struct {
Name string
Encoding EncodingType
Body []byte
Body io.Reader
}
func (m *DataMessage) TypeID() TypeID {
@@ -156,7 +160,7 @@ type CommandMessage struct {
CommandName string
TransactionID int64
Encoding EncodingType
Body []byte
Body io.Reader
}
func (m *CommandMessage) TypeID() TypeID {
+72
View File
@@ -0,0 +1,72 @@
//
// Copyright (c) 2018- yutopp (yutopp@gmail.com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt)
//
package message
import (
"bytes"
"github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"testing"
)
func assertEqualMessage(t *testing.T, expected, actual Message) {
assert.Equal(t, expected.TypeID(), actual.TypeID())
switch expected := expected.(type) {
case *AudioMessage:
actual, ok := actual.(*AudioMessage)
assert.True(t, ok)
assertEqualPayload(t, expected.Payload, actual.Payload)
case *VideoMessage:
actual, ok := actual.(*VideoMessage)
assert.True(t, ok)
assertEqualPayload(t, expected.Payload, actual.Payload)
case *DataMessage:
actual, ok := actual.(*DataMessage)
assert.True(t, ok)
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Encoding, actual.Encoding)
assertEqualPayload(t, expected.Body, actual.Body)
case *CommandMessage:
actual, ok := actual.(*CommandMessage)
assert.True(t, ok)
assert.Equal(t, expected.CommandName, actual.CommandName)
assert.Equal(t, expected.TransactionID, actual.TransactionID)
assert.Equal(t, expected.Encoding, actual.Encoding)
assertEqualPayload(t, expected.Body, actual.Body)
default:
assert.Equal(t, expected, actual)
}
}
func assertEqualPayload(t *testing.T, expected, actual io.Reader) {
expectedBin, err := ioutil.ReadAll(expected)
assert.Nil(t, err)
switch p := expected.(type) {
case *bytes.Reader:
defer p.Seek(0, io.SeekStart) // Restore test case states
default:
t.FailNow()
}
assert.NotZero(t, len(expectedBin))
actualBin, err := ioutil.ReadAll(actual)
assert.Nil(t, err)
assert.NotZero(t, len(actualBin))
assert.Equal(t, expectedBin, actualBin)
}
+5 -7
View File
@@ -79,11 +79,10 @@ func (s *Stream) Connect() (*message.NetConnectionConnectResult, error) {
// TODO: check result
select {
case <-t.doneCh:
r := bytes.NewReader(t.body)
amfDec := message.NewAMFDecoder(r, t.encoding)
amfDec := message.NewAMFDecoder(t.body, t.encoding)
var value message.AMFConvertible
if err := message.DecodeBodyConnectResult(r, amfDec, &value); err != nil {
if err := message.DecodeBodyConnectResult(t.body, amfDec, &value); err != nil {
return nil, errors.Wrap(err, "Failed to decode result")
}
result := value.(*message.NetConnectionConnectResult)
@@ -144,11 +143,10 @@ func (s *Stream) CreateStream() (*message.NetConnectionCreateStreamResult, error
// TODO: check result
select {
case <-t.doneCh:
r := bytes.NewReader(t.body)
amfDec := message.NewAMFDecoder(r, t.encoding)
amfDec := message.NewAMFDecoder(t.body, t.encoding)
var value message.AMFConvertible
if err := message.DecodeBodyCreateStreamResult(r, amfDec, &value); err != nil {
if err := message.DecodeBodyCreateStreamResult(t.body, amfDec, &value); err != nil {
return nil, errors.Wrap(err, "Failed to decode result")
}
result := value.(*message.NetConnectionCreateStreamResult)
@@ -222,7 +220,7 @@ func (s *Stream) writeCommandMessage(
CommandName: commandName,
TransactionID: transactionID,
Encoding: s.encTy,
Body: buf.Bytes(),
Body: buf,
})
}
+5 -11
View File
@@ -8,7 +8,6 @@
package rtmp
import (
"bytes"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"sync"
@@ -152,10 +151,9 @@ func (h *streamHandler) handleData(
) error {
bodyDecoder := message.DataBodyDecoderFor(dataMsg.Name)
r := bytes.NewReader(dataMsg.Body)
amfDec := message.NewAMFDecoder(r, dataMsg.Encoding)
amfDec := message.NewAMFDecoder(dataMsg.Body, dataMsg.Encoding)
var value message.AMFConvertible
if err := bodyDecoder(r, amfDec, &value); err != nil {
if err := bodyDecoder(dataMsg.Body, amfDec, &value); err != nil {
return err
}
@@ -179,10 +177,7 @@ func (h *streamHandler) handleCommand(
}
// Set result (NOTE: shoule use a mutex for t?)
t.commandName = cmdMsg.CommandName
t.encoding = cmdMsg.Encoding
t.body = cmdMsg.Body
close(t.doneCh)
t.Reply(cmdMsg.CommandName, cmdMsg.Encoding, cmdMsg.Body)
// Remove transacaction because this transaction is resolved
if err := h.stream.transactions.Delete(cmdMsg.TransactionID); err != nil {
@@ -194,12 +189,11 @@ func (h *streamHandler) handleCommand(
// TODO: Support onStatus
}
r := bytes.NewReader(cmdMsg.Body)
amfDec := message.NewAMFDecoder(r, cmdMsg.Encoding)
amfDec := message.NewAMFDecoder(cmdMsg.Body, cmdMsg.Encoding)
bodyDecoder := message.CmdBodyDecoderFor(cmdMsg.CommandName, cmdMsg.TransactionID)
var value message.AMFConvertible
if err := bodyDecoder(r, amfDec, &value); err != nil {
if err := bodyDecoder(cmdMsg.Body, amfDec, &value); err != nil {
return err
}
+13 -1
View File
@@ -8,7 +8,9 @@
package rtmp
import (
"bytes"
"github.com/pkg/errors"
"io"
"sync"
"github.com/yutopp/go-rtmp/message"
@@ -17,10 +19,20 @@ import (
type transaction struct {
commandName string
encoding message.EncodingType
body []byte
body *bytes.Buffer
lastErr error
doneCh chan struct{}
}
func (t *transaction) Reply(commandName string, encoding message.EncodingType, body io.Reader) {
t.commandName = commandName
t.encoding = encoding
t.body = new(bytes.Buffer)
_, err := io.Copy(t.body, body)
t.lastErr = err
close(t.doneCh)
}
type transactions struct {
transactions map[int64]*transaction
m sync.RWMutex