mirror of
https://github.com/encounter/go-rtmp.git
synced 2026-03-30 11:12:49 -07:00
481 lines
11 KiB
Go
481 lines
11 KiB
Go
//
|
|
// Copyright (c) 2018- yutopp (yutopp@gmail.com)
|
|
//
|
|
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
|
// file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt)
|
|
//
|
|
|
|
package rtmp
|
|
|
|
import (
|
|
"context"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"io"
|
|
"math"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/yutopp/go-rtmp/message"
|
|
)
|
|
|
|
// ctrlMsgChunkStreamID is the chunk stream ID on which sendAck emits its Ack
// control messages.
const (
	ctrlMsgChunkStreamID = 2
)
|
|
|
|
// maxWriterQueueSize is the buffer capacity of the channel through which
// chunk stream writers are handed to the writer scheduler.
const (
	maxWriterQueueSize = 64
)
|
|
|
|
type ChunkMessage struct {
|
|
StreamID uint32
|
|
Message message.Message
|
|
}
|
|
|
|
type ChunkStreamer struct {
|
|
r *ChunkStreamerReader
|
|
w *ChunkStreamerWriter
|
|
|
|
readers map[int]*ChunkStreamReader
|
|
writers map[int]*ChunkStreamWriter
|
|
mu sync.Mutex
|
|
|
|
writerSched *chunkStreamerWriterSched
|
|
|
|
msgDec *message.Decoder
|
|
msgEnc *message.Encoder
|
|
|
|
selfState *StreamControlState
|
|
peerState *StreamControlState
|
|
|
|
err error
|
|
done chan struct{}
|
|
|
|
controlStreamWriter func(chunkStreamID int, timestamp uint32, msg message.Message) error
|
|
|
|
config *StreamControlStateConfig
|
|
logger logrus.FieldLogger
|
|
}
|
|
|
|
func NewChunkStreamer(r io.Reader, w io.Writer, config *StreamControlStateConfig) *ChunkStreamer {
|
|
if config == nil {
|
|
config = defaultStreamControlStateConfig
|
|
}
|
|
|
|
cs := &ChunkStreamer{
|
|
r: &ChunkStreamerReader{
|
|
reader: r,
|
|
},
|
|
w: &ChunkStreamerWriter{
|
|
writer: w,
|
|
},
|
|
|
|
readers: make(map[int]*ChunkStreamReader),
|
|
writers: make(map[int]*ChunkStreamWriter),
|
|
|
|
writerSched: &chunkStreamerWriterSched{
|
|
writers: make(chan *ChunkStreamWriter, maxWriterQueueSize),
|
|
stopCh: make(chan struct{}),
|
|
},
|
|
|
|
msgDec: message.NewDecoder(nil),
|
|
msgEnc: message.NewEncoder(nil),
|
|
|
|
selfState: NewStreamControlState(config),
|
|
peerState: NewStreamControlState(config),
|
|
|
|
done: make(chan struct{}),
|
|
|
|
config: config,
|
|
logger: logrus.StandardLogger(),
|
|
}
|
|
cs.writerSched.streamer = cs
|
|
go cs.schedWriteLoop()
|
|
|
|
return cs
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Read(cmsg *ChunkMessage) (
|
|
chunkStreamID int,
|
|
timestamp uint32,
|
|
closer func(),
|
|
err error,
|
|
) {
|
|
reader, err := cs.NewChunkReader()
|
|
if err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
closer = func() { reader.Close() }
|
|
defer func() {
|
|
if err != nil {
|
|
closer()
|
|
}
|
|
}()
|
|
|
|
cs.msgDec.Reset(reader)
|
|
if err := cs.msgDec.Decode(message.TypeID(reader.messageTypeID), &cmsg.Message); err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
|
|
cmsg.StreamID = reader.messageStreamID
|
|
|
|
return reader.basicHeader.chunkStreamID, uint32(reader.timestamp), closer, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Write(
|
|
ctx context.Context, // NOTE: Retire writing when a current chunk is busy
|
|
chunkStreamID int,
|
|
timestamp uint32,
|
|
cmsg *ChunkMessage,
|
|
) error {
|
|
writer, err := cs.NewChunkWriter(ctx, chunkStreamID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
//defer writer.Close()
|
|
|
|
cs.msgEnc.Reset(writer)
|
|
if err := cs.msgEnc.Encode(cmsg.Message); err != nil {
|
|
return err
|
|
}
|
|
writer.timestamp = timestamp
|
|
writer.messageLength = uint32(writer.buf.Len())
|
|
writer.messageTypeID = byte(cmsg.Message.TypeID())
|
|
writer.messageStreamID = cmsg.StreamID
|
|
|
|
return cs.Sched(writer)
|
|
}
|
|
|
|
func (cs *ChunkStreamer) NewChunkReader() (*ChunkStreamReader, error) {
|
|
again:
|
|
isCompleted, reader, err := cs.readChunk()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cs.r.FragmentReadBytes() >= uint32(cs.peerState.ackWindowSize/2) { // TODO: fix size
|
|
if err := cs.sendAck(cs.r.TotalReadBytes()); err != nil {
|
|
return nil, err
|
|
}
|
|
cs.r.ResetFragmentReadBytes()
|
|
}
|
|
|
|
if !isCompleted {
|
|
goto again
|
|
}
|
|
return reader, nil
|
|
}
|
|
|
|
// NewChunkWriter Returns a writer for a chunkStreamID.
|
|
// Wait until writing have been finished if the writer is running.
|
|
func (cs *ChunkStreamer) NewChunkWriter(ctx context.Context, chunkStreamID int) (*ChunkStreamWriter, error) {
|
|
writer, err := cs.prepareChunkWriter(chunkStreamID)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "Failed to prepare chunk writer")
|
|
}
|
|
if err := writer.Wait(ctx); err != nil {
|
|
return nil, errors.Wrapf(err, "Failed to wait chunk writer")
|
|
}
|
|
|
|
return writer, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Sched(writer *ChunkStreamWriter) error {
|
|
return cs.writerSched.Sched(writer)
|
|
}
|
|
|
|
func (cs *ChunkStreamer) SelfState() *StreamControlState {
|
|
return cs.selfState
|
|
}
|
|
|
|
func (cs *ChunkStreamer) PeerState() *StreamControlState {
|
|
return cs.peerState
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Done() <-chan struct{} {
|
|
return cs.done
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Err() error {
|
|
return cs.err
|
|
}
|
|
|
|
func (cs *ChunkStreamer) Close() error {
|
|
return cs.writerSched.Close()
|
|
}
|
|
|
|
// returns nil reader when chunk is fragmented.
|
|
func (cs *ChunkStreamer) readChunk() (bool, *ChunkStreamReader, error) {
|
|
var bh chunkBasicHeader
|
|
if err := decodeChunkBasicHeader(cs.r, &bh); err != nil {
|
|
return false, nil, err
|
|
}
|
|
cs.logger.Debugf("(READ) BasicHeader = %+v", bh)
|
|
|
|
var mh chunkMessageHeader
|
|
if err := decodeChunkMessageHeader(cs.r, bh.fmt, &mh); err != nil {
|
|
return false, nil, err
|
|
}
|
|
cs.logger.Debugf("(READ) MessageHeader = %+v", mh)
|
|
|
|
reader, err := cs.prepareChunkReader(bh.chunkStreamID)
|
|
if err != nil {
|
|
return false, nil, errors.Wrapf(err, "Failed to prepare chunk reader")
|
|
}
|
|
|
|
reader.basicHeader = bh
|
|
reader.messageHeader = mh
|
|
|
|
switch bh.fmt {
|
|
case 0:
|
|
reader.timestamp = uint64(mh.timestamp)
|
|
reader.timestampDelta = 0 // reset
|
|
reader.messageLength = mh.messageLength
|
|
reader.messageTypeID = mh.messageTypeID
|
|
reader.messageStreamID = mh.messageStreamID
|
|
|
|
case 1:
|
|
reader.timestampDelta = mh.timestampDelta
|
|
reader.messageLength = mh.messageLength
|
|
reader.messageTypeID = mh.messageTypeID
|
|
|
|
case 2:
|
|
reader.timestampDelta = mh.timestampDelta
|
|
|
|
case 3:
|
|
// DO NOTHING
|
|
|
|
default:
|
|
panic("unsupported chunk") // TODO: fix
|
|
}
|
|
|
|
cs.logger.Debugf("(READ) MessageLength = %d, Current = %d", reader.messageLength, reader.buf.Len())
|
|
|
|
expectLen := int(reader.messageLength) - reader.buf.Len()
|
|
if expectLen <= 0 {
|
|
panic("invalid state") // TODO fix
|
|
}
|
|
|
|
if uint32(expectLen) > cs.peerState.chunkSize {
|
|
expectLen = int(cs.peerState.chunkSize)
|
|
}
|
|
cs.logger.Debugf("(READ) Length = %d", expectLen)
|
|
|
|
if _, err := io.CopyN(&reader.buf, cs.r, int64(expectLen)); err != nil {
|
|
return false, nil, err
|
|
}
|
|
//cs.logger.Debugf("(READ) Buffer: %+v", reader.buf.Bytes())
|
|
|
|
if int(reader.messageLength)-reader.buf.Len() != 0 {
|
|
// fragmented
|
|
return false, reader, nil
|
|
}
|
|
|
|
// read completed, update timestamp
|
|
reader.timestamp += uint64(reader.timestampDelta)
|
|
|
|
return true, reader, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) writeChunk(writer *ChunkStreamWriter) (bool, error) {
|
|
cs.updateWriterHeader(writer)
|
|
|
|
cs.logger.Debugf("(WRITE) Headers: Basic = %+v / Message = %+v", writer.basicHeader, writer.messageHeader)
|
|
//cs.logger.Debugf("(WRITE) Buffer: %+v", writer.buf.Bytes())
|
|
|
|
expectLen := writer.buf.Len()
|
|
if uint32(expectLen) > cs.selfState.chunkSize {
|
|
expectLen = int(cs.selfState.chunkSize)
|
|
}
|
|
|
|
if err := encodeChunkBasicHeader(cs.w, &writer.basicHeader); err != nil {
|
|
return false, err
|
|
}
|
|
if err := encodeChunkMessageHeader(cs.w, writer.basicHeader.fmt, &writer.messageHeader); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if _, err := io.CopyN(cs.w, writer, int64(expectLen)); err != nil {
|
|
return false, err
|
|
}
|
|
if err := cs.w.Flush(); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if writer.buf.Len() != 0 {
|
|
// fragmented
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) updateWriterHeader(writer *ChunkStreamWriter) {
|
|
fmt := byte(2) // default: only timestamp delta
|
|
if writer.messageHeader.messageLength != writer.messageLength || writer.messageTypeID != writer.messageHeader.messageTypeID {
|
|
// header or type id is updated, change fmt to 1 to notify difference and update state
|
|
writer.messageHeader.messageLength = writer.messageLength
|
|
writer.messageHeader.messageTypeID = writer.messageTypeID
|
|
fmt = 1
|
|
}
|
|
if writer.timestamp != writer.messageHeader.timestamp {
|
|
if writer.timestamp >= writer.messageHeader.timestamp {
|
|
writer.timestampDelta = writer.timestamp - writer.messageHeader.timestamp
|
|
} else {
|
|
// timestamp is reversed, clear timestamp data
|
|
fmt = 0
|
|
writer.timestampDelta = 0
|
|
}
|
|
}
|
|
if writer.timestampDelta == writer.messageHeader.timestampDelta && fmt == 2 {
|
|
fmt = 3
|
|
}
|
|
writer.messageHeader.timestampDelta = writer.timestampDelta
|
|
writer.messageHeader.timestamp = writer.timestamp
|
|
|
|
if writer.messageHeader.messageStreamID != writer.messageStreamID {
|
|
fmt = 0
|
|
writer.messageHeader.messageStreamID = writer.messageStreamID
|
|
}
|
|
writer.basicHeader.fmt = fmt
|
|
}
|
|
|
|
func (cs *ChunkStreamer) waitWriters() {
|
|
cs.mu.Lock()
|
|
defer cs.mu.Unlock()
|
|
|
|
// Wait until that writers are finished for 3 seconds. (NOTE: 3s is adhoc value...)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
for k, writer := range cs.writers {
|
|
if err := writer.Wait(ctx); err != nil {
|
|
cs.logger.Warnf("Failed to wait writer: ID = %d", k)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (cs *ChunkStreamer) forceCloseWriters() {
|
|
cs.mu.Lock()
|
|
defer cs.mu.Unlock()
|
|
|
|
for _, writer := range cs.writers {
|
|
//writer.lastErr = cs.err
|
|
close(writer.closeCh)
|
|
}
|
|
}
|
|
|
|
func (cs *ChunkStreamer) schedWriteLoop() {
|
|
defer close(cs.done)
|
|
cs.err = cs.writerSched.Run()
|
|
|
|
if cs.err != nil {
|
|
cs.forceCloseWriters()
|
|
}
|
|
}
|
|
|
|
func (cs *ChunkStreamer) prepareChunkReader(chunkStreamID int) (*ChunkStreamReader, error) {
|
|
cs.mu.Lock()
|
|
defer cs.mu.Unlock()
|
|
|
|
reader, ok := cs.readers[chunkStreamID]
|
|
if !ok {
|
|
if len(cs.readers) >= cs.config.MaxChunkStreams {
|
|
return nil, errors.Errorf(
|
|
"Creating chunk streams limit exceeded(Reader): Limit = %d",
|
|
cs.config.MaxChunkStreams,
|
|
)
|
|
}
|
|
|
|
reader = &ChunkStreamReader{}
|
|
cs.readers[chunkStreamID] = reader
|
|
}
|
|
|
|
return reader, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) prepareChunkWriter(chunkStreamID int) (*ChunkStreamWriter, error) {
|
|
cs.mu.Lock()
|
|
defer cs.mu.Unlock()
|
|
|
|
writer, ok := cs.writers[chunkStreamID]
|
|
if !ok {
|
|
if len(cs.writers) >= cs.config.MaxChunkStreams {
|
|
return nil, errors.Errorf(
|
|
"Creating chunk streams limit exceeded(Writer): Limit = %d",
|
|
cs.config.MaxChunkStreams,
|
|
)
|
|
}
|
|
|
|
writer = &ChunkStreamWriter{
|
|
basicHeader: chunkBasicHeader{
|
|
chunkStreamID: chunkStreamID,
|
|
},
|
|
messageHeader: chunkMessageHeader{
|
|
timestamp: math.MaxUint32, // initial state will be updated by writer.timestamp
|
|
},
|
|
doneCh: make(chan struct{}),
|
|
closeCh: make(chan struct{}),
|
|
}
|
|
close(writer.doneCh)
|
|
cs.writers[chunkStreamID] = writer
|
|
}
|
|
|
|
return writer, nil
|
|
}
|
|
|
|
func (cs *ChunkStreamer) sendAck(readBytes uint32) error {
|
|
cs.logger.Debugf("Sending Ack...: Bytes = %d", readBytes)
|
|
// TODO: fix timestamp
|
|
return cs.controlStreamWriter(ctrlMsgChunkStreamID, 0, &message.Ack{
|
|
SequenceNumber: readBytes,
|
|
})
|
|
}
|
|
|
|
type chunkStreamerWriterSched struct {
|
|
streamer *ChunkStreamer
|
|
writers chan *ChunkStreamWriter
|
|
stopCh chan struct{}
|
|
}
|
|
|
|
func (sched *chunkStreamerWriterSched) Sched(writer *ChunkStreamWriter) error {
|
|
sched.writers <- writer
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sched *chunkStreamerWriterSched) Run() (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errTmp, ok := r.(error)
|
|
if !ok {
|
|
errTmp = errors.Errorf("Panic: %+v", r)
|
|
}
|
|
err = errors.WithStack(errTmp)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case writer := <-sched.writers:
|
|
isCompleted, err := sched.streamer.writeChunk(writer)
|
|
if err != nil {
|
|
writer.lastErr = err
|
|
close(writer.doneCh)
|
|
return err
|
|
}
|
|
if isCompleted {
|
|
close(writer.doneCh)
|
|
continue
|
|
}
|
|
|
|
// Enqueue writer
|
|
sched.writers <- writer
|
|
|
|
case <-sched.stopCh:
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sched *chunkStreamerWriterSched) Close() error {
|
|
close(sched.stopCh)
|
|
|
|
return nil
|
|
}
|