diff --git a/agent.go b/agent.go index 11986b0..8f53659 100644 --- a/agent.go +++ b/agent.go @@ -3,6 +3,8 @@ package ice import ( + "bytes" + "encoding/binary" "fmt" "math/rand" "net" @@ -29,6 +31,8 @@ const ( // the number of bytes that can be buffered before we start to error maxBufferSize = 1000 * 1000 // 1MB + + stunAttrHeaderLength = 4 ) // Agent represents the ICE agent @@ -768,9 +772,88 @@ func (a *Agent) handleNewPeerReflexiveCandidate(local *Candidate, remote net.Add return nil } +func (a *Agent) assertInboundUsername(m *stun.Message) error { + usernameAttr := &stun.Username{} + usernameRawAttr, usernameFound := m.GetOneAttribute(stun.AttrUsername) + + if !usernameFound { + return fmt.Errorf("inbound packet missing Username") + } else if err := usernameAttr.Unpack(m, usernameRawAttr); err != nil { + return err + } + + expectedUsername := a.localUfrag + ":" + a.remoteUfrag + if usernameAttr.Username != expectedUsername { + return fmt.Errorf("username mismatch expected(%x) actual(%x)", expectedUsername, usernameAttr.Username) + } + + return nil +} + +func (a *Agent) assertInboundMessageIntegrity(m *stun.Message, key []byte) error { + messageIntegrityAttr := &stun.MessageIntegrity{} + messageIntegrityRawAttr, messageIntegrityAttrFound := m.GetOneAttribute(stun.AttrMessageIntegrity) + + if !messageIntegrityAttrFound { + return fmt.Errorf("inbound packet missing MessageIntegrity") + } else if err := messageIntegrityAttr.Unpack(m, messageIntegrityRawAttr); err != nil { + return err + } + + tailLength := messageIntegrityRawAttr.Length + stunAttrHeaderLength + rawCopy := make([]byte, len(m.Raw)) + copy(rawCopy, m.Raw) + + // If we have a fingerprint we need to exclude it from the MessageIntegrity computation + if rawFingerprint, hasFingerprint := m.GetOneAttribute(stun.AttrFingerprint); hasFingerprint { + fingerprintLength := rawFingerprint.Length + stunAttrHeaderLength + tailLength += fingerprintLength + + // Rewrite the packet header to be new length (excluding values we don't care about) + currLength := binary.BigEndian.Uint16(rawCopy[2:4]) + binary.BigEndian.PutUint16(rawCopy[2:], currLength-fingerprintLength) + } + + lengthToHash := len(rawCopy) - int(tailLength) + if lengthToHash < 1 { + return fmt.Errorf("unable to assert MessageIntegrity, length calculation failed (%d)", lengthToHash) + } + + computedMessageIntegrity, err := stun.MessageIntegrityCalculateHMAC(key, rawCopy[:lengthToHash]) + if err != nil { + return err + } else if !bytes.Equal(computedMessageIntegrity, messageIntegrityRawAttr.Value) { + return fmt.Errorf("messageIntegrity mismatch expected(%x) actual(%x)", computedMessageIntegrity, messageIntegrityRawAttr.Value) + } + + return nil +} + // handleInbound processes STUN traffic from a remote candidate func (a *Agent) handleInbound(m *stun.Message, local *Candidate, remote net.Addr) { + if m == nil || local == nil { + return + } a.log.Tracef("inbound STUN from %s to %s", remote.String(), local.String()) + + switch { + case m.Method == stun.MethodBinding && m.Class == stun.ClassSuccessResponse: + if err := a.assertInboundMessageIntegrity(m, []byte(a.remotePwd)); err != nil { + a.log.Warnf("discard message from (%s), %v", remote, err) + return + } + case m.Method == stun.MethodBinding && m.Class == stun.ClassRequest: + if err := a.assertInboundUsername(m); err != nil { + a.log.Warnf("discard message from (%s), %v", remote, err) + return + } else if err := a.assertInboundMessageIntegrity(m, []byte(a.localPwd)); err != nil { + a.log.Warnf("discard message from (%s), %v", remote, err) + return + } + default: + return + } + remoteCandidate := a.findRemoteCandidate(local.NetworkType, remote) if remoteCandidate == nil { a.log.Debugf("detected a new peer-reflexive candiate: %s ", remote) @@ -781,13 +864,8 @@ func (a *Agent) handleInbound(m *stun.Message, local *Candidate, remote net.Addr } return } - remoteCandidate.seen(false) - if m.Class == stun.ClassIndication { - return - } - if a.isControlling { a.handleInboundControlling(m, local, remoteCandidate) } else { diff --git a/agent_test.go b/agent_test.go index 6da1ab0..a3597eb 100644 --- a/agent_test.go +++ b/agent_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/pion/stun" "github.com/pion/transport/test" ) @@ -199,7 +200,21 @@ func TestHandlePeerReflexive(t *testing.T) { remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} - a.handleInbound(nil, local, remote) + msg, err := stun.Build(stun.ClassRequest, stun.MethodBinding, stun.GenerateTransactionID(), + &stun.Username{Username: a.localUfrag + ":" + a.remoteUfrag}, + &stun.UseCandidate{}, + &stun.IceControlling{TieBreaker: a.tieBreaker}, + &stun.Priority{Priority: local.Priority()}, + &stun.MessageIntegrity{ + Key: []byte(a.localPwd), + }, + &stun.Fingerprint{}, + ) + if err != nil { + t.Fatal(err) + } + + a.handleInbound(msg, local, remote) // length of remote candidate list must be one now if len(a.remoteCandidates) != 1 { @@ -351,3 +366,101 @@ func TestConnectivityOnStartup(t *testing.T) { <-aConnected <-bConnected } + +func TestInboundValidity(t *testing.T) { + buildMsg := func(class stun.MessageClass, username, key string) *stun.Message { + msg, err := stun.Build(class, stun.MethodBinding, stun.GenerateTransactionID(), + &stun.Username{Username: username}, + &stun.MessageIntegrity{ + Key: []byte(key), + }, + &stun.Fingerprint{}, + ) + if err != nil { + t.Fatal(err) + } + + return msg + } + + remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} + local, err := NewCandidateHost("udp", net.ParseIP("192.168.0.2"), 777, 1) + if err != nil { + t.Fatalf("failed to create a new candidate: %v", err) + } + + t.Run("Invalid Binding requests should be discarded", func(t *testing.T) { + a, err := NewAgent(&AgentConfig{}) + if err != nil { + t.Fatalf("Error constructing ice.Agent") + } + + a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote) + if len(a.remoteCandidates) == 1 { + t.Fatal("Binding with invalid Username was able to create prflx candidate") + } + + a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) + if len(a.remoteCandidates) == 1 { + t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate") + } + }) + + t.Run("Invalid Binding success responses should be discarded", func(t *testing.T) { + a, err := NewAgent(&AgentConfig{}) + if err != nil { + t.Fatalf("Error constructing ice.Agent") + } + + a.handleInbound(buildMsg(stun.ClassSuccessResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) + if len(a.remoteCandidates) == 1 { + t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate") + } + }) + + t.Run("Discard non-binding messages", func(t *testing.T) { + a, err := NewAgent(&AgentConfig{}) + if err != nil { + t.Fatalf("Error constructing ice.Agent") + } + + a.handleInbound(buildMsg(stun.ClassErrorResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) + if len(a.remoteCandidates) == 1 { + t.Fatal("non-binding message was able to create prflxRemote") + } + }) + + t.Run("Valid bind request", func(t *testing.T) { + a, err := NewAgent(&AgentConfig{}) + if err != nil { + t.Fatalf("Error constructing ice.Agent") + } + + a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote) + if len(a.remoteCandidates) != 1 { + t.Fatal("Binding with valid values was unable to create prflx candidate") + } + }) + + t.Run("Valid bind without fingerprint", func(t *testing.T) { + a, err := NewAgent(&AgentConfig{}) + if err != nil { + t.Fatalf("Error constructing ice.Agent") + } + + msg, err := stun.Build(stun.ClassRequest, stun.MethodBinding, stun.GenerateTransactionID(), + &stun.Username{Username: a.localUfrag + ":" + a.remoteUfrag}, + &stun.MessageIntegrity{ + Key: []byte(a.localPwd), + }, + ) + if err != nil { + t.Fatal(err) + } + + a.handleInbound(msg, local, remote) + if len(a.remoteCandidates) != 1 { + t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate") + } + }) +}