mirror of
https://github.com/netbirdio/gvisor.git
synced 2026-05-22 17:12:49 -07:00
committed by
gVisor bot
parent
efea407471
commit
c3ff31ef8c
@@ -244,7 +244,7 @@ type protocolMode int
|
||||
|
||||
const (
|
||||
protocolModeV2 protocolMode = iota
|
||||
protocolModeV1Forced
|
||||
protocolModeV1
|
||||
protocolModeV1Compatibility
|
||||
)
|
||||
|
||||
@@ -292,30 +292,35 @@ type GenericMulticastProtocolState struct {
|
||||
stateChangedReportV2TimerSet bool
|
||||
}
|
||||
|
||||
// SetForcedV1ModeLocked sets the V1 forced configuration.
|
||||
// SetV1ModeLocked sets the V1 configuration.
|
||||
//
|
||||
// Returns the previous configuration.
|
||||
//
|
||||
// Precondition: g.protocolMU must be locked.
|
||||
func (g *GenericMulticastProtocolState) SetForcedV1ModeLocked(v bool) {
|
||||
func (g *GenericMulticastProtocolState) SetV1ModeLocked(v bool) bool {
|
||||
if v {
|
||||
switch g.mode {
|
||||
case protocolModeV2:
|
||||
g.cancelV2ReportTimers()
|
||||
case protocolModeV1Compatibility:
|
||||
g.modeTimer.Stop()
|
||||
case protocolModeV1Forced:
|
||||
// Already in V1 forced mode; nothing to do.
|
||||
case protocolModeV1:
|
||||
// Already in V1 mode; nothing to do.
|
||||
return true
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
}
|
||||
g.mode = protocolModeV1Forced
|
||||
return
|
||||
g.mode = protocolModeV1
|
||||
return false
|
||||
}
|
||||
|
||||
switch g.mode {
|
||||
case protocolModeV2, protocolModeV1Compatibility:
|
||||
// Not in V1 forced mode; nothing to do.
|
||||
case protocolModeV1Forced:
|
||||
// Not in V1 mode; nothing to do.
|
||||
return false
|
||||
case protocolModeV1:
|
||||
g.mode = protocolModeV2
|
||||
return true
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
}
|
||||
@@ -387,7 +392,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
|
||||
groupAddress,
|
||||
)
|
||||
}
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
handler = g.transitionToNonMemberLocked
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
@@ -434,7 +439,7 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
|
||||
switch g.mode {
|
||||
case protocolModeV2:
|
||||
v2ReportBuilder = g.opts.Protocol.NewReportV2Builder()
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
}
|
||||
@@ -481,7 +486,7 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
|
||||
switch g.mode {
|
||||
case protocolModeV2:
|
||||
g.sendV2ReportAndMaybeScheduleChangedTimer(groupAddress, &info, MulticastGroupProtocolV2ReportRecordChangeToExcludeMode)
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
g.maybeSendReportLocked(groupAddress, &info)
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
@@ -528,7 +533,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
|
||||
// Nothing meaningful we can do with the error here - we only try to
|
||||
// send a delayed report once.
|
||||
_, _ = reportBuilder.Send()
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
g.maybeSendReportLocked(groupAddress, &info)
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
@@ -674,7 +679,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr
|
||||
} else {
|
||||
delete(g.memberships, groupAddress)
|
||||
}
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
g.transitionToNonMemberLocked(groupAddress, &info)
|
||||
delete(g.memberships, groupAddress)
|
||||
default:
|
||||
@@ -693,7 +698,7 @@ func (g *GenericMulticastProtocolState) HandleQueryV2Locked(groupAddress tcpip.A
|
||||
}
|
||||
|
||||
switch g.mode {
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
g.handleQueryInnerLocked(groupAddress, g.opts.Protocol.V2QueryMaxRespCodeToV1Delay(maxResponseCode))
|
||||
return
|
||||
case protocolModeV2:
|
||||
@@ -914,7 +919,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add
|
||||
}
|
||||
g.mode = protocolModeV1Compatibility
|
||||
g.cancelV2ReportTimers()
|
||||
case protocolModeV1Forced:
|
||||
case protocolModeV1:
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
|
||||
}
|
||||
@@ -996,7 +1001,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
|
||||
callersV2ReportBuilder.AddRecord(MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, groupAddress)
|
||||
info.transmissionLeft--
|
||||
}
|
||||
case protocolModeV1Compatibility, protocolModeV1Forced:
|
||||
case protocolModeV1Compatibility, protocolModeV1:
|
||||
info.transmissionLeft = unsolicitedTransmissionCount
|
||||
g.maybeSendReportLocked(groupAddress, info)
|
||||
default:
|
||||
|
||||
@@ -60,7 +60,7 @@ func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOption
|
||||
m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
|
||||
|
||||
if v1Compatibility {
|
||||
m.mu.genericMulticastGroup.SetForcedV1ModeLocked(true)
|
||||
m.mu.genericMulticastGroup.SetV1ModeLocked(true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,10 +82,10 @@ func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
|
||||
m.mu.makeQueuePackets = v
|
||||
}
|
||||
|
||||
func (m *mockMulticastGroupProtocol) setForcedV1Mode(v bool) {
|
||||
func (m *mockMulticastGroupProtocol) setV1Mode(v bool) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.mu.genericMulticastGroup.SetForcedV1ModeLocked(v)
|
||||
return m.mu.genericMulticastGroup.SetV1ModeLocked(v)
|
||||
}
|
||||
|
||||
func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
|
||||
@@ -1506,7 +1506,7 @@ func TestQueuedPackets(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestV1Compatibility(t *testing.T) {
|
||||
func TestSetV1Mode(t *testing.T) {
|
||||
clock := faketime.NewManualClock()
|
||||
mgp := mockMulticastGroupProtocol{t: t}
|
||||
mgp.init(ip.GenericMulticastProtocolOptions{
|
||||
@@ -1525,13 +1525,17 @@ func TestV1Compatibility(t *testing.T) {
|
||||
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
mgp.setForcedV1Mode(true)
|
||||
if mgp.setV1Mode(true) {
|
||||
t.Error("got mgp.setV1Mode(true) = true, want = false")
|
||||
}
|
||||
mgp.joinGroup(addr2)
|
||||
if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" {
|
||||
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
mgp.setForcedV1Mode(false)
|
||||
if !mgp.setV1Mode(false) {
|
||||
t.Error("got mgp.setV1Mode(false) = false, want = true")
|
||||
}
|
||||
mgp.joinGroup(addr3)
|
||||
if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
|
||||
{
|
||||
|
||||
@@ -180,6 +180,7 @@ var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
|
||||
var _ stack.AddressableEndpoint = (*endpoint)(nil)
|
||||
var _ stack.NetworkEndpoint = (*endpoint)(nil)
|
||||
var _ stack.NDPEndpoint = (*endpoint)(nil)
|
||||
var _ MLDEndpoint = (*endpoint)(nil)
|
||||
var _ NDPEndpoint = (*endpoint)(nil)
|
||||
|
||||
type endpoint struct {
|
||||
@@ -356,6 +357,13 @@ func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
|
||||
e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
|
||||
}
|
||||
|
||||
// SetMLDVersion implements MLDEndpoint.
|
||||
func (e *endpoint) SetMLDVersion(v MLDVersion) MLDVersion {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.mu.mld.setVersion(v)
|
||||
}
|
||||
|
||||
// SetNDPConfigurations implements NDPEndpoint.
|
||||
func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
|
||||
c.validate()
|
||||
|
||||
@@ -33,6 +33,26 @@ const (
|
||||
UnsolicitedReportIntervalMax = 10 * time.Second
|
||||
)
|
||||
|
||||
// MLDVersion is the forced version of MLD.
|
||||
type MLDVersion int
|
||||
|
||||
const (
|
||||
_ MLDVersion = iota
|
||||
// MLDVersion1 indicates MLDv1.
|
||||
MLDVersion1
|
||||
// MLDVersion2 indicates MLDv2. Note that MLD may still fallback to V1
|
||||
// compatibility mode as required by MLDv2.
|
||||
MLDVersion2
|
||||
)
|
||||
|
||||
// MLDEndpoint is a network endpoint that supports MLD.
|
||||
type MLDEndpoint interface {
|
||||
// Sets the MLD version.
|
||||
//
|
||||
// Returns the previous MLD version.
|
||||
SetMLDVersion(MLDVersion) MLDVersion
|
||||
}
|
||||
|
||||
// MLDOptions holds options for MLD.
|
||||
type MLDOptions struct {
|
||||
// Enabled indicates whether MLD will be performed.
|
||||
@@ -300,6 +320,26 @@ func (mld *mldState) sendQueuedReports() {
|
||||
mld.genericMulticastProtocol.SendQueuedReportsLocked()
|
||||
}
|
||||
|
||||
// setVersion sets the MLD version.
|
||||
//
|
||||
// Precondition: mld.ep.mu must be locked.
|
||||
func (mld *mldState) setVersion(v MLDVersion) MLDVersion {
|
||||
var prev bool
|
||||
switch v {
|
||||
case MLDVersion2:
|
||||
prev = mld.genericMulticastProtocol.SetV1ModeLocked(false)
|
||||
case MLDVersion1:
|
||||
prev = mld.genericMulticastProtocol.SetV1ModeLocked(true)
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized version = %d", v))
|
||||
}
|
||||
|
||||
if prev {
|
||||
return MLDVersion1
|
||||
}
|
||||
return MLDVersion2
|
||||
}
|
||||
|
||||
// writePacket assembles and sends an MLD packet.
|
||||
//
|
||||
// Precondition: mld.ep.mu must be read locked.
|
||||
|
||||
@@ -44,6 +44,24 @@ var (
|
||||
globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
|
||||
)
|
||||
|
||||
func checkVersion(t *testing.T, s *stack.Stack, nicID tcpip.NICID, v1 bool) {
|
||||
if !v1 {
|
||||
return
|
||||
}
|
||||
|
||||
ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
|
||||
}
|
||||
|
||||
mldEP, ok := ep.(ipv6.MLDEndpoint)
|
||||
if !ok {
|
||||
t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP)
|
||||
}
|
||||
|
||||
mldEP.SetMLDVersion(ipv6.MLDVersion1)
|
||||
}
|
||||
|
||||
func validateMLDPacket(t *testing.T, v *bufferv2.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
|
||||
t.Helper()
|
||||
|
||||
@@ -141,17 +159,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
|
||||
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
|
||||
}
|
||||
|
||||
if test.v1Compatibility {
|
||||
createAndInjectMLDPacket(
|
||||
e,
|
||||
header.ICMPv6MulticastListenerQuery,
|
||||
header.MLDHopLimit,
|
||||
linkLocalAddr,
|
||||
header.IPv6Any,
|
||||
true, /* withRouterAlertOption */
|
||||
header.IPv6RouterAlertMLD,
|
||||
)
|
||||
}
|
||||
checkVersion(t, s, nicID, test.v1Compatibility)
|
||||
|
||||
// The stack will join an address's solicited node multicast address when
|
||||
// an address is added. An MLD report message should be sent for the
|
||||
@@ -306,20 +314,7 @@ func TestSendQueuedMLDReports(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
checkVersion := func() {
|
||||
if subTest.v1Compatibility {
|
||||
createAndInjectMLDPacket(
|
||||
e,
|
||||
header.ICMPv6MulticastListenerQuery,
|
||||
header.MLDHopLimit,
|
||||
linkLocalAddr,
|
||||
unusedMulticastAddr,
|
||||
true, /* withRouterAlertOption */
|
||||
header.IPv6RouterAlertMLD,
|
||||
)
|
||||
}
|
||||
}
|
||||
checkVersion()
|
||||
checkVersion(t, s, nicID, subTest.v1Compatibility)
|
||||
|
||||
var reportCounter uint64
|
||||
var doneCounter uint64
|
||||
@@ -335,7 +330,6 @@ func TestSendQueuedMLDReports(t *testing.T) {
|
||||
subTest.checkStats(t, s, reportCounter, doneCounter, reportV2Counter)
|
||||
subTest.validate(t, e, header.IPv6Any, []tcpip.Address{globalMulticastAddr}, false /* leave */)
|
||||
clock.Advance(time.Hour)
|
||||
checkVersion()
|
||||
if p := e.Read(); !p.IsNil() {
|
||||
t.Errorf("got unexpected packet = %#v", p)
|
||||
p.DecRef()
|
||||
@@ -720,17 +714,7 @@ func TestMLDSkipProtocol(t *testing.T) {
|
||||
defer e.Close()
|
||||
defer c.cleanup()
|
||||
|
||||
if subTest.v1Compatibility {
|
||||
createAndInjectMLDPacket(
|
||||
e,
|
||||
header.ICMPv6MulticastListenerQuery,
|
||||
header.MLDHopLimit,
|
||||
linkLocalAddr,
|
||||
header.IPv6Any,
|
||||
true, /* withRouterAlertOption */
|
||||
header.IPv6RouterAlertMLD,
|
||||
)
|
||||
}
|
||||
checkVersion(t, s, nicID, subTest.v1Compatibility)
|
||||
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
@@ -775,6 +759,70 @@ func TestMLDSkipProtocol(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMLDVersion(t *testing.T) {
|
||||
const nicID = 1
|
||||
|
||||
c := newMLDTestContext()
|
||||
s := c.s
|
||||
|
||||
e := channel.New(1, header.IPv6MinimumMTU, "")
|
||||
if err := s.CreateNIC(nicID, e); err != nil {
|
||||
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
|
||||
}
|
||||
|
||||
defer e.Close()
|
||||
defer c.cleanup()
|
||||
|
||||
ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
|
||||
}
|
||||
mldEP, ok := ep.(ipv6.MLDEndpoint)
|
||||
if !ok {
|
||||
t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP)
|
||||
}
|
||||
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: linkLocalAddr.WithPrefix(),
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
|
||||
t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
|
||||
}
|
||||
if p := e.Read(); p.IsNil() {
|
||||
t.Fatal("expected a report message to be sent")
|
||||
} else {
|
||||
validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.MLDv2ReportRecordChangeToExcludeMode)
|
||||
p.DecRef()
|
||||
}
|
||||
|
||||
if got := mldEP.SetMLDVersion(ipv6.MLDVersion1); got != ipv6.MLDVersion2 {
|
||||
t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion1, got, ipv6.MLDVersion2)
|
||||
}
|
||||
if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
|
||||
t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
|
||||
}
|
||||
if p := e.Read(); p.IsNil() {
|
||||
t.Fatal("expected a report message to be sent")
|
||||
} else {
|
||||
validateMLDPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
|
||||
p.DecRef()
|
||||
}
|
||||
|
||||
if got := mldEP.SetMLDVersion(ipv6.MLDVersion2); got != ipv6.MLDVersion1 {
|
||||
t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion2, got, ipv6.MLDVersion1)
|
||||
}
|
||||
if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
|
||||
t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
|
||||
}
|
||||
if p := e.Read(); p.IsNil() {
|
||||
t.Fatal("expected a report message to be sent")
|
||||
} else {
|
||||
validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.MLDv2ReportRecordChangeToIncludeMode)
|
||||
p.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
refs.SetLeakMode(refs.LeaksPanic)
|
||||
code := m.Run()
|
||||
|
||||
Reference in New Issue
Block a user