netstack: parse incoming packet headers up-front

Netstack has traditionally parsed headers on-demand as a packet moves up the
stack. This is conceptually simple and convenient, but incompatible with
iptables, where headers can be inspected and mangled before even a routing
decision is made.

This changes header parsing to happen early in the incoming packet path, as soon
as the NIC gets the packet from a link endpoint. Even if an invalid packet is
found (e.g. a TCP header of insufficient length), the packet is passed up the
stack for proper stats bookkeeping.

PiperOrigin-RevId: 315179302
This commit is contained in:
Kevin Krakauer
2020-06-07 13:37:25 -07:00
committed by gVisor bot
parent 6260304179
commit 32b823fcdb
28 changed files with 539 additions and 364 deletions
+4 -30
View File
@@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN
return false, false
}
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
// ultimately be moved into the stack.Check codepath as matchers are
// added.
var tcpHeader header.TCP
if pkt.TransportHeader != nil {
tcpHeader = header.TCP(pkt.TransportHeader)
} else {
var length int
if hook == stack.Prerouting {
// The network header hasn't been parsed yet. We have to do it here.
hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
h := header.IPv4(hdr)
pkt.NetworkHeader = hdr
length = int(h.HeaderLength())
}
// The TCP header hasn't been parsed yet. We have to do it here.
hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize)
if !ok {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
tcpHeader = header.TCP(hdr[length:])
tcpHeader := header.TCP(pkt.TransportHeader)
if len(tcpHeader) < header.TCPMinimumSize {
// There's no valid TCP header here, so we drop the packet immediately.
return false, true
}
// Check whether the source and destination ports are within the
+4 -30
View File
@@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN
return false, false
}
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
// ultimately be moved into the stack.Check codepath as matchers are
// added.
var udpHeader header.UDP
if pkt.TransportHeader != nil {
udpHeader = header.UDP(pkt.TransportHeader)
} else {
var length int
if hook == stack.Prerouting {
// The network header hasn't been parsed yet. We have to do it here.
hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
h := header.IPv4(hdr)
pkt.NetworkHeader = hdr
length = int(h.HeaderLength())
}
// The UDP header hasn't been parsed yet. We have to do it here.
hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize)
if !ok {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
udpHeader = header.UDP(hdr[length:])
udpHeader := header.UDP(pkt.TransportHeader)
if len(udpHeader) < header.UDPMinimumSize {
// There's no valid UDP header here, so we drop the packet immediately.
return false, true
}
// Check whether the source and destination ports are within the
+5
View File
@@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
// More returns whether the more fragments flag is set.
func (b IPv4) More() bool {
return b.Flags()&IPv4FlagMoreFragments != 0
}
// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
@@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 {
return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
}
// IsAtomic returns whether the fragment header indicates an atomic fragment. An
// atomic fragment is a fragment that contains all the data required to
// reassemble a full packet.
func (b IPv6FragmentExtHdr) IsAtomic() bool {
return !b.More() && b.FragmentOffset() == 0
}
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
//
// The IPv6 payload may contain IPv6 extension headers before any upper layer
+12 -5
View File
@@ -99,11 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
v, ok := pkt.Data.PullUp(header.ARPSize)
if !ok {
return
}
h := header.ARP(v)
h := header.ARP(pkt.NetworkHeader)
if !h.IsValid() {
return
}
@@ -209,6 +205,17 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
hdr, ok := pkt.Data.PullUp(header.ARPSize)
if !ok {
return 0, false, false
}
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(header.ARPSize)
return 0, false, true
}
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
@@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
}
}
// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
// Process processes an incoming fragment belonging to an ID and returns a
// complete packet when all the packets belonging to that ID have been received.
func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
+29 -19
View File
@@ -293,9 +293,9 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
proto.Parse(&pkt)
ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -382,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: vv,
})
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -448,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: frag1.ToVectorisedView(),
})
pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()}
proto.Parse(&pkt)
ep.HandlePacket(&r, &pkt)
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: frag2.ToVectorisedView(),
})
pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()}
proto.Parse(&pkt)
ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -538,9 +535,9 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
proto.Parse(&pkt)
ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -652,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: view[:len(view)-c.trunc].ToVectorisedView(),
})
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}
// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
// after truncation, is large enough to hold a network header, it makes part of
// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
// becomes Data.
func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
v := view[:len(view)-trunc]
if len(v) < netHdrLen {
return &stack.PacketBuffer{Data: v.ToVectorisedView()}
}
return &stack.PacketBuffer{
NetworkHeader: v[:netHdrLen],
Data: v[netHdrLen:].ToVectorisedView(),
}
}
+3
View File
@@ -59,6 +59,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
received.Invalid.Increment()
+40 -34
View File
@@ -21,6 +21,7 @@
package ipv4
import (
"fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -268,14 +269,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
handleLoopback(&route, pkt, ep)
ep.HandlePacket(&route, pkt)
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
handleLoopback(&loopedR, pkt, e)
e.HandlePacket(&loopedR, pkt)
loopedR.Release()
}
if r.Loop&stack.PacketOut == 0 {
@@ -291,17 +292,6 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
func handleLoopback(route *stack.Route, pkt *stack.PacketBuffer, ep stack.NetworkEndpoint) {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
ep.HandlePacket(route, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
@@ -339,12 +329,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader)
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
handleLoopback(&route, pkt, ep)
ep.HandlePacket(&route, pkt)
n++
continue
}
@@ -418,22 +407,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
h := header.IPv4(pkt.NetworkHeader)
if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
pkt.NetworkHeader = headerView[:h.HeaderLength()]
hlen := int(h.HeaderLength())
tlen := int(h.TotalLength())
pkt.Data.TrimFront(hlen)
pkt.Data.CapLength(tlen - hlen)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
@@ -443,9 +421,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
if more || h.FragmentOffset() != 0 {
if pkt.Data.Size() == 0 {
if h.More() || h.FragmentOffset() != 0 {
if pkt.Data.Size()+len(pkt.TransportHeader) == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -464,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
var ready bool
var err error
pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -476,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
headerView.CapLength(hlen)
pkt.NetworkHeader.CapLength(int(h.HeaderLength()))
e.handleICMP(r, pkt)
return
}
@@ -556,6 +533,35 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return 0, false, false
}
ipHdr := header.IPv4(hdr)
// If there are options, pull those into hdr as well.
if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() {
hdr, ok = pkt.Data.PullUp(headerLen)
if !ok {
panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen))
}
ipHdr = header.IPv4(hdr)
}
// If this is a fragment, don't bother parsing the transport header.
parseTransportHeader := true
if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
parseTransportHeader = false
}
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(len(hdr))
pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
return ipHdr.TransportProtocol(), parseTransportHeader, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
+12
View File
@@ -652,6 +652,18 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload1, udpPayload2},
},
{
name: "Fragment without followup",
fragments: []fragmentData{
{
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
payload: ipv4Payload1[:64],
},
},
expectedPayloads: nil,
},
}
for _, test := range tests {
+5 -2
View File
@@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
iph := header.IPv6(netHeader)
iph := header.IPv6(pkt.NetworkHeader)
// Validate ICMPv6 checksum before processing the packet.
//
+18 -25
View File
@@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) {
},
}
handleIPv6Payload := func(hdr buffer.Prependable) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
handleIPv6Payload := func(icmp header.ICMPv6) {
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
ep.HandlePacket(&r, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
NetworkHeader: buffer.View(ip),
Data: buffer.View(icmp).ToVectorisedView(),
})
}
for _, typ := range types {
extraDataLen := len(typ.extraData)
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
extraData := buffer.View(hdr.Prepend(extraDataLen))
copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
handleIPv6Payload(hdr)
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
handleIPv6Payload := func(checksum bool) {
extraDataLen := len(typ.extraData)
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
extraData := buffer.View(hdr.Prepend(extraDataLen))
copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
if checksum {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(typ.size + extraDataLen),
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
})
}
+98 -20
View File
@@ -171,22 +171,20 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
h := header.IPv6(pkt.NetworkHeader)
if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
pkt.Data.TrimFront(header.IPv6MinimumSize)
pkt.Data.CapLength(int(h.PayloadLength()))
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView()
vv.AppendView(pkt.TransportHeader)
vv.Append(pkt.Data)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
for firstHeader := true; ; firstHeader = false {
@@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6FragmentExtHdr:
hasFragmentHeader = true
fragmentOffset := extHdr.FragmentOffset()
more := extHdr.More()
if !more && fragmentOffset == 0 {
if extHdr.IsAtomic() {
// This fragment extension header indicates that this packet is an
// atomic fragment. An atomic fragment is a fragment that contains
// all the data required to reassemble a full packet. As per RFC 6946,
@@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
if fragmentOffset == 0 {
if extHdr.FragmentOffset() == 0 {
// Check that the iterator ends with a raw payload as the first fragment
// should include all headers up to and including any upper layer
// headers, as per RFC 8200 section 4.5; only upper layer data
@@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
// The packet is a fragment, let's try to reassemble it.
start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
last := start + uint16(fragmentPayloadLen) - 1
// Drop the packet if the fragmentOffset is incorrect. i.e the
@@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
var ready bool
pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6RawPayloadHeader:
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
// For unfragmented packets, extHdr still contains the transport header.
// Get rid of it.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
extHdr.Buf.TrimFront(len(pkt.TransportHeader))
pkt.Data = extHdr.Buf
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
e.handleICMP(r, headerView, pkt, hasFragmentHeader)
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
// TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
@@ -505,6 +510,79 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return 0, false, false
}
ipHdr := header.IPv6(hdr)
// dataClone consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
views := [8]buffer.View{}
dataClone := pkt.Data.Clone(views[:])
dataClone.TrimFront(header.IPv6MinimumSize)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
// Iterate over the IPv6 extensions to find their length.
//
// Parsing occurs again in HandlePacket because we don't track the
// extensions in PacketBuffer. Unfortunately, that means HandlePacket
// has to do the parsing work again.
var nextHdr tcpip.TransportProtocolNumber
foundNext := true
extensionsSize := 0
traverseExtensions:
for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
if err != nil {
break
}
// If we exhaust the extension list, the entire packet is the IPv6 header
// and (possibly) extensions.
if done {
extensionsSize = dataClone.Size()
foundNext = false
break
}
switch extHdr := extHdr.(type) {
case header.IPv6FragmentExtHdr:
// If this is an atomic fragment, we don't have to treat it specially.
if !extHdr.More() && extHdr.FragmentOffset() == 0 {
continue
}
// This is a non-atomic fragment and has to be re-assembled before we can
// examine the payload for a transport header.
foundNext = false
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
extensionsSize = dataClone.Size() - extHdr.Buf.Size()
nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
break traverseExtensions
default:
// Any other extension is a no-op, keep looping until we find the payload.
}
}
// Put the IPv6 header with extensions in pkt.NetworkHeader.
hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize)
if !ok {
panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
}
ipHdr = header.IPv6(hdr)
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(len(hdr))
pkt.Data.CapLength(int(ipHdr.PayloadLength()))
return nextHdr, foundNext, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
+17 -16
View File
@@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) {
return s, ep, r
}
handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
var extensions buffer.View
if atomicFragment {
bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
bytes[0] = nextHdr
extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
extensions[0] = nextHdr
nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
}
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions)))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
ep.HandlePacket(r, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
NetworkHeader: buffer.View(ip),
Data: payload.ToVectorisedView(),
})
}
@@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
extraDataLen := len(typ.extraData)
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
extraData := buffer.View(hdr.Prepend(extraDataLen))
copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
pkt.SetCode(test.code)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
-46
View File
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
@@ -147,44 +146,6 @@ type ConnTrackTable struct {
Seed uint32
}
// parseHeaders sets headers in the packet.
func parseHeaders(pkt *PacketBuffer) {
newPkt := pkt.Clone()
// Set network header.
hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
}
netHeader := header.IPv4(hdr)
newPkt.NetworkHeader = hdr
length := int(netHeader.HeaderLength())
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
// Set transport header.
switch protocol := netHeader.TransportProtocol(); protocol {
case header.UDPProtocolNumber:
if newPkt.TransportHeader == nil {
h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
if !ok {
return
}
newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
}
case header.TCPProtocolNumber:
if newPkt.TransportHeader == nil {
h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
if !ok {
return
}
newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
}
}
pkt.NetworkHeader = newPkt.NetworkHeader
pkt.TransportHeader = newPkt.TransportHeader
}
// packetToTuple converts packet to a tuple in original direction.
func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
var tuple connTrackTuple
@@ -257,13 +218,6 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
// transport protocols.
func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
if hook == Prerouting {
// Headers will not be set in Prerouting.
// TODO(gvisor.dev/issue/170): Change this after parsing headers
// code is added.
parseHeaders(pkt)
}
var dir ctDirection
tuple, err := packetToTuple(pkt, hook)
if err != nil {
+43 -35
View File
@@ -33,6 +33,10 @@ const (
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
dstAddrOffset = 0
srcAddrOffset = 1
protocolNumberOffset = 2
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -69,15 +73,8 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Consume the network header.
b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fwdTestNetHeaderLen)
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -100,9 +97,9 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
b[dstAddrOffset] = r.RemoteAddress[0]
b[srcAddrOffset] = f.id.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
@@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
if !ok {
return 0, false, false
}
pkt.NetworkHeader = netHeader
pkt.Data.TrimFront(fwdTestNetHeaderLen)
return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -361,7 +368,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -398,7 +405,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -429,7 +436,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -459,7 +466,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
buf[0] = 4
buf[dstAddrOffset] = 4
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -467,7 +474,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -509,7 +515,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -554,7 +559,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
buf[0] = 3
buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
@@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
}
// The first 5 packets should not be forwarded so the the
// sequemnce number should start with 5.
seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
if !ok {
t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
}
// The first 5 packets should not be forwarded so the sequence number should
// start with 5.
want := uint16(i + 5)
if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
@@ -609,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
buf[0] = byte(3 + i)
buf[dstAddrOffset] = byte(3 + i)
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
b := p.Pkt.Data.ToView()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
-5
View File
@@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
// Set network header.
if hook == Prerouting {
parseHeaders(pkt)
}
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
return RuleDrop, 0
+44 -7
View File
@@ -1212,12 +1212,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stack.stats.IP.PacketsReceived.Increment()
}
netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
// Parse headers.
transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
// The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(netHeader)
if hasTransportHdr {
// Parse the transport header if present.
if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
state.proto.Parse(pkt)
}
}
src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1301,8 +1310,18 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
pkt.Header = buffer.NewPrependable(linkHeaderLen)
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
// by having lower layers explicity write each header instead of just
// pkt.Header.
// pkt may have set its NetworkHeader and TransportHeader. If we're
// forwarding, we'll have to copy them into pkt.Header.
pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
}
if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
}
// WritePacket takes ownership of pkt, calculate numBytes first.
@@ -1333,13 +1352,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
if !ok {
// TransportHeader is nil only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader == nil {
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
pkt.TransportHeader = transHeader
} else {
// This is either a bad packet or was re-assembled from fragments.
transProto.Parse(pkt)
}
}
if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
srcPort, dstPort, err := transProto.ParsePorts(transHeader)
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
+13
View File
@@ -168,6 +168,11 @@ type TransportProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
// Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
// neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
// MinimumPacketSize()
Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -313,6 +318,14 @@ type NetworkProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
// Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
// returns:
// - The encapsulated protocol, if present.
// - Whether there is an encapsulated transport protocol payload (e.g. ARP
// does not encapsulate anything).
// - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
+39 -31
View File
@@ -52,6 +52,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
dstAddrOffset = 0
srcAddrOffset = 1
protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -94,26 +98,24 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
f.dispatcher.DeliverTransportControlPacket(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -138,18 +140,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fakeNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
f.HandlePacket(r, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
f.HandlePacket(r, pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
@@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return 0, false, false
}
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(fakeNetHeaderLen)
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -292,7 +300,7 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
buf[0] = 3
buf[dstAddrOffset] = 3
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -304,7 +312,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
buf[0] = 1
buf[dstAddrOffset] = 1
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -316,7 +324,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
buf[0] = 2
buf[dstAddrOffset] = 2
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
if promiscuous {
if err := s.SetPromiscuousMode(nicID, true); err != nil {
@@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
buf[0] = localAddrByte
buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -2344,7 +2352,7 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
buf[0] = dstAddr[0]
buf[dstAddrOffset] = dstAddr[0]
ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})

Some files were not shown because too many files have changed in this diff Show More