mirror of
https://github.com/netbirdio/ice.git
synced 2026-05-22 17:10:58 -07:00
Assert MessageIntegrity/Username for inbound
When we get an inbound message assert these values, also discard any other packet types besides binding. In the future we should extend to handle inbound error messages Resolves #19 Resolves #21
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -29,6 +31,8 @@ const (
|
||||
|
||||
// the number of bytes that can be buffered before we start to error
|
||||
maxBufferSize = 1000 * 1000 // 1MB
|
||||
|
||||
stunAttrHeaderLength = 4
|
||||
)
|
||||
|
||||
// Agent represents the ICE agent
|
||||
@@ -768,9 +772,88 @@ func (a *Agent) handleNewPeerReflexiveCandidate(local *Candidate, remote net.Add
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) assertInboundUsername(m *stun.Message) error {
|
||||
usernameAttr := &stun.Username{}
|
||||
usernameRawAttr, usernameFound := m.GetOneAttribute(stun.AttrUsername)
|
||||
|
||||
if !usernameFound {
|
||||
return fmt.Errorf("inbound packet missing Username")
|
||||
} else if err := usernameAttr.Unpack(m, usernameRawAttr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedUsername := a.localUfrag + ":" + a.remoteUfrag
|
||||
if usernameAttr.Username != expectedUsername {
|
||||
return fmt.Errorf("username mismatch expected(%x) actual(%x)", expectedUsername, usernameAttr.Username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) assertInboundMessageIntegrity(m *stun.Message, key []byte) error {
|
||||
messageIntegrityAttr := &stun.MessageIntegrity{}
|
||||
messageIntegrityRawAttr, messageIntegrityAttrFound := m.GetOneAttribute(stun.AttrMessageIntegrity)
|
||||
|
||||
if !messageIntegrityAttrFound {
|
||||
return fmt.Errorf("inbound packet missing MessageIntegrity")
|
||||
} else if err := messageIntegrityAttr.Unpack(m, messageIntegrityRawAttr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailLength := messageIntegrityRawAttr.Length + stunAttrHeaderLength
|
||||
rawCopy := make([]byte, len(m.Raw))
|
||||
copy(rawCopy, m.Raw)
|
||||
|
||||
// If we have a fingerprint we need to exclude it from the MessageIntegrity computation
|
||||
if rawFingerprint, hasFingerprint := m.GetOneAttribute(stun.AttrFingerprint); hasFingerprint {
|
||||
fingerprintLength := rawFingerprint.Length + stunAttrHeaderLength
|
||||
tailLength += fingerprintLength
|
||||
|
||||
// Rewrite the packet header to be new length (excluding values we don't care about)
|
||||
currLength := binary.BigEndian.Uint16(rawCopy[2:4])
|
||||
binary.BigEndian.PutUint16(rawCopy[2:], currLength-fingerprintLength)
|
||||
}
|
||||
|
||||
lengthToHash := len(rawCopy) - int(tailLength)
|
||||
if lengthToHash < 1 {
|
||||
return fmt.Errorf("unable to assert MessageIntegrity, length calculation failed (%d)", lengthToHash)
|
||||
}
|
||||
|
||||
computedMessageIntegrity, err := stun.MessageIntegrityCalculateHMAC(key, rawCopy[:lengthToHash])
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !bytes.Equal(computedMessageIntegrity, messageIntegrityRawAttr.Value) {
|
||||
return fmt.Errorf("messageIntegrity mismatch expected(%x) actual(%x)", computedMessageIntegrity, messageIntegrityRawAttr.Value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInbound processes STUN traffic from a remote candidate
|
||||
func (a *Agent) handleInbound(m *stun.Message, local *Candidate, remote net.Addr) {
|
||||
if m == nil || local == nil {
|
||||
return
|
||||
}
|
||||
a.log.Tracef("inbound STUN from %s to %s", remote.String(), local.String())
|
||||
|
||||
switch {
|
||||
case m.Method == stun.MethodBinding && m.Class == stun.ClassSuccessResponse:
|
||||
if err := a.assertInboundMessageIntegrity(m, []byte(a.remotePwd)); err != nil {
|
||||
a.log.Warnf("discard message from (%s), %v", remote, err)
|
||||
return
|
||||
}
|
||||
case m.Method == stun.MethodBinding && m.Class == stun.ClassRequest:
|
||||
if err := a.assertInboundUsername(m); err != nil {
|
||||
a.log.Warnf("discard message from (%s), %v", remote, err)
|
||||
return
|
||||
} else if err := a.assertInboundMessageIntegrity(m, []byte(a.localPwd)); err != nil {
|
||||
a.log.Warnf("discard message from (%s), %v", remote, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
remoteCandidate := a.findRemoteCandidate(local.NetworkType, remote)
|
||||
if remoteCandidate == nil {
|
||||
a.log.Debugf("detected a new peer-reflexive candiate: %s ", remote)
|
||||
@@ -781,13 +864,8 @@ func (a *Agent) handleInbound(m *stun.Message, local *Candidate, remote net.Addr
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
remoteCandidate.seen(false)
|
||||
|
||||
if m.Class == stun.ClassIndication {
|
||||
return
|
||||
}
|
||||
|
||||
if a.isControlling {
|
||||
a.handleInboundControlling(m, local, remoteCandidate)
|
||||
} else {
|
||||
|
||||
+114
-1
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/test"
|
||||
)
|
||||
|
||||
@@ -199,7 +200,21 @@ func TestHandlePeerReflexive(t *testing.T) {
|
||||
|
||||
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
|
||||
|
||||
a.handleInbound(nil, local, remote)
|
||||
msg, err := stun.Build(stun.ClassRequest, stun.MethodBinding, stun.GenerateTransactionID(),
|
||||
&stun.Username{Username: a.localUfrag + ":" + a.remoteUfrag},
|
||||
&stun.UseCandidate{},
|
||||
&stun.IceControlling{TieBreaker: a.tieBreaker},
|
||||
&stun.Priority{Priority: local.Priority()},
|
||||
&stun.MessageIntegrity{
|
||||
Key: []byte(a.localPwd),
|
||||
},
|
||||
&stun.Fingerprint{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a.handleInbound(msg, local, remote)
|
||||
|
||||
// length of remote candidate list must be one now
|
||||
if len(a.remoteCandidates) != 1 {
|
||||
@@ -351,3 +366,101 @@ func TestConnectivityOnStartup(t *testing.T) {
|
||||
<-aConnected
|
||||
<-bConnected
|
||||
}
|
||||
|
||||
func TestInboundValidity(t *testing.T) {
|
||||
buildMsg := func(class stun.MessageClass, username, key string) *stun.Message {
|
||||
msg, err := stun.Build(class, stun.MethodBinding, stun.GenerateTransactionID(),
|
||||
&stun.Username{Username: username},
|
||||
&stun.MessageIntegrity{
|
||||
Key: []byte(key),
|
||||
},
|
||||
&stun.Fingerprint{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
|
||||
local, err := NewCandidateHost("udp", net.ParseIP("192.168.0.2"), 777, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create a new candidate: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Invalid Binding requests should be discarded", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
t.Fatal("Binding with invalid Username was able to create prflx candidate")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid Binding success responses should be discarded", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassSuccessResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Discard non-binding messages", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassErrorResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
t.Fatal("non-binding message was able to create prflxRemote")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid bind request", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
|
||||
if len(a.remoteCandidates) != 1 {
|
||||
t.Fatal("Binding with valid values was unable to create prflx candidate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid bind without fingerprint", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
|
||||
msg, err := stun.Build(stun.ClassRequest, stun.MethodBinding, stun.GenerateTransactionID(),
|
||||
&stun.Username{Username: a.localUfrag + ":" + a.remoteUfrag},
|
||||
&stun.MessageIntegrity{
|
||||
Key: []byte(a.localPwd),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a.handleInbound(msg, local, remote)
|
||||
if len(a.remoteCandidates) != 1 {
|
||||
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user