Allow setting MLD version

Updates #8346

PiperOrigin-RevId: 507851551
This commit is contained in:
Ghanan Gowripalan
2023-02-07 12:16:56 -08:00
committed by gVisor bot
parent efea407471
commit c3ff31ef8c
5 changed files with 165 additions and 60 deletions
@@ -244,7 +244,7 @@ type protocolMode int
const (
protocolModeV2 protocolMode = iota
protocolModeV1Forced
protocolModeV1
protocolModeV1Compatibility
)
@@ -292,30 +292,35 @@ type GenericMulticastProtocolState struct {
stateChangedReportV2TimerSet bool
}
// SetForcedV1ModeLocked sets the V1 forced configuration.
// SetV1ModeLocked sets the V1 configuration.
//
// Returns the previous configuration.
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) SetForcedV1ModeLocked(v bool) {
func (g *GenericMulticastProtocolState) SetV1ModeLocked(v bool) bool {
if v {
switch g.mode {
case protocolModeV2:
g.cancelV2ReportTimers()
case protocolModeV1Compatibility:
g.modeTimer.Stop()
case protocolModeV1Forced:
// Already in V1 forced mode; nothing to do.
case protocolModeV1:
// Already in V1 mode; nothing to do.
return true
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
}
g.mode = protocolModeV1Forced
return
g.mode = protocolModeV1
return false
}
switch g.mode {
case protocolModeV2, protocolModeV1Compatibility:
// Not in V1 forced mode; nothing to do.
case protocolModeV1Forced:
// Not in V1 mode; nothing to do.
return false
case protocolModeV1:
g.mode = protocolModeV2
return true
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
}
@@ -387,7 +392,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
groupAddress,
)
}
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
handler = g.transitionToNonMemberLocked
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
@@ -434,7 +439,7 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
switch g.mode {
case protocolModeV2:
v2ReportBuilder = g.opts.Protocol.NewReportV2Builder()
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
}
@@ -481,7 +486,7 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
switch g.mode {
case protocolModeV2:
g.sendV2ReportAndMaybeScheduleChangedTimer(groupAddress, &info, MulticastGroupProtocolV2ReportRecordChangeToExcludeMode)
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
g.maybeSendReportLocked(groupAddress, &info)
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
@@ -528,7 +533,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
// Nothing meaningful we can do with the error here - we only try to
// send a delayed report once.
_, _ = reportBuilder.Send()
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
g.maybeSendReportLocked(groupAddress, &info)
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
@@ -674,7 +679,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr
} else {
delete(g.memberships, groupAddress)
}
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
g.transitionToNonMemberLocked(groupAddress, &info)
delete(g.memberships, groupAddress)
default:
@@ -693,7 +698,7 @@ func (g *GenericMulticastProtocolState) HandleQueryV2Locked(groupAddress tcpip.A
}
switch g.mode {
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
g.handleQueryInnerLocked(groupAddress, g.opts.Protocol.V2QueryMaxRespCodeToV1Delay(maxResponseCode))
return
case protocolModeV2:
@@ -914,7 +919,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add
}
g.mode = protocolModeV1Compatibility
g.cancelV2ReportTimers()
case protocolModeV1Forced:
case protocolModeV1:
default:
panic(fmt.Sprintf("unrecognized mode = %d", g.mode))
}
@@ -996,7 +1001,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
callersV2ReportBuilder.AddRecord(MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, groupAddress)
info.transmissionLeft--
}
case protocolModeV1Compatibility, protocolModeV1Forced:
case protocolModeV1Compatibility, protocolModeV1:
info.transmissionLeft = unsolicitedTransmissionCount
g.maybeSendReportLocked(groupAddress, info)
default:
@@ -60,7 +60,7 @@ func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOption
m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
if v1Compatibility {
m.mu.genericMulticastGroup.SetForcedV1ModeLocked(true)
m.mu.genericMulticastGroup.SetV1ModeLocked(true)
}
}
@@ -82,10 +82,10 @@ func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
m.mu.makeQueuePackets = v
}
func (m *mockMulticastGroupProtocol) setForcedV1Mode(v bool) {
func (m *mockMulticastGroupProtocol) setV1Mode(v bool) bool {
m.mu.Lock()
defer m.mu.Unlock()
m.mu.genericMulticastGroup.SetForcedV1ModeLocked(v)
return m.mu.genericMulticastGroup.SetV1ModeLocked(v)
}
func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
@@ -1506,7 +1506,7 @@ func TestQueuedPackets(t *testing.T) {
}
}
func TestV1Compatibility(t *testing.T) {
func TestSetV1Mode(t *testing.T) {
clock := faketime.NewManualClock()
mgp := mockMulticastGroupProtocol{t: t}
mgp.init(ip.GenericMulticastProtocolOptions{
@@ -1525,13 +1525,17 @@ func TestV1Compatibility(t *testing.T) {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.setForcedV1Mode(true)
if mgp.setV1Mode(true) {
t.Error("got mgp.setV1Mode(true) = true, want = false")
}
mgp.joinGroup(addr2)
if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.setForcedV1Mode(false)
if !mgp.setV1Mode(false) {
t.Error("got mgp.setV1Mode(false) = false, want = true")
}
mgp.joinGroup(addr3)
if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
{
+8
View File
@@ -180,6 +180,7 @@ var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
var _ stack.NDPEndpoint = (*endpoint)(nil)
var _ MLDEndpoint = (*endpoint)(nil)
var _ NDPEndpoint = (*endpoint)(nil)
type endpoint struct {
@@ -356,6 +357,13 @@ func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetMLDVersion implements MLDEndpoint.
func (e *endpoint) SetMLDVersion(v MLDVersion) MLDVersion {
e.mu.Lock()
defer e.mu.Unlock()
return e.mu.mld.setVersion(v)
}
// SetNDPConfigurations implements NDPEndpoint.
func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
c.validate()
+40
View File
@@ -33,6 +33,26 @@ const (
UnsolicitedReportIntervalMax = 10 * time.Second
)
// MLDVersion is the forced version of MLD.
type MLDVersion int
const (
_ MLDVersion = iota
// MLDVersion1 indicates MLDv1.
MLDVersion1
// MLDVersion2 indicates MLDv2. Note that MLD may still fallback to V1
// compatibility mode as required by MLDv2.
MLDVersion2
)
// MLDEndpoint is a network endpoint that supports MLD.
type MLDEndpoint interface {
// Sets the MLD version.
//
// Returns the previous MLD version.
SetMLDVersion(MLDVersion) MLDVersion
}
// MLDOptions holds options for MLD.
type MLDOptions struct {
// Enabled indicates whether MLD will be performed.
@@ -300,6 +320,26 @@ func (mld *mldState) sendQueuedReports() {
mld.genericMulticastProtocol.SendQueuedReportsLocked()
}
// setVersion sets the MLD version.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) setVersion(v MLDVersion) MLDVersion {
var prev bool
switch v {
case MLDVersion2:
prev = mld.genericMulticastProtocol.SetV1ModeLocked(false)
case MLDVersion1:
prev = mld.genericMulticastProtocol.SetV1ModeLocked(true)
default:
panic(fmt.Sprintf("unrecognized version = %d", v))
}
if prev {
return MLDVersion1
}
return MLDVersion2
}
// writePacket assembles and sends an MLD packet.
//
// Precondition: mld.ep.mu must be read locked.
+85 -37
View File
@@ -44,6 +44,24 @@ var (
globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
)
func checkVersion(t *testing.T, s *stack.Stack, nicID tcpip.NICID, v1 bool) {
if !v1 {
return
}
ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
mldEP, ok := ep.(ipv6.MLDEndpoint)
if !ok {
t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP)
}
mldEP.SetMLDVersion(ipv6.MLDVersion1)
}
func validateMLDPacket(t *testing.T, v *bufferv2.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
t.Helper()
@@ -141,17 +159,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
if test.v1Compatibility {
createAndInjectMLDPacket(
e,
header.ICMPv6MulticastListenerQuery,
header.MLDHopLimit,
linkLocalAddr,
header.IPv6Any,
true, /* withRouterAlertOption */
header.IPv6RouterAlertMLD,
)
}
checkVersion(t, s, nicID, test.v1Compatibility)
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
@@ -306,20 +314,7 @@ func TestSendQueuedMLDReports(t *testing.T) {
}
}
checkVersion := func() {
if subTest.v1Compatibility {
createAndInjectMLDPacket(
e,
header.ICMPv6MulticastListenerQuery,
header.MLDHopLimit,
linkLocalAddr,
unusedMulticastAddr,
true, /* withRouterAlertOption */
header.IPv6RouterAlertMLD,
)
}
}
checkVersion()
checkVersion(t, s, nicID, subTest.v1Compatibility)
var reportCounter uint64
var doneCounter uint64
@@ -335,7 +330,6 @@ func TestSendQueuedMLDReports(t *testing.T) {
subTest.checkStats(t, s, reportCounter, doneCounter, reportV2Counter)
subTest.validate(t, e, header.IPv6Any, []tcpip.Address{globalMulticastAddr}, false /* leave */)
clock.Advance(time.Hour)
checkVersion()
if p := e.Read(); !p.IsNil() {
t.Errorf("got unexpected packet = %#v", p)
p.DecRef()
@@ -720,17 +714,7 @@ func TestMLDSkipProtocol(t *testing.T) {
defer e.Close()
defer c.cleanup()
if subTest.v1Compatibility {
createAndInjectMLDPacket(
e,
header.ICMPv6MulticastListenerQuery,
header.MLDHopLimit,
linkLocalAddr,
header.IPv6Any,
true, /* withRouterAlertOption */
header.IPv6RouterAlertMLD,
)
}
checkVersion(t, s, nicID, subTest.v1Compatibility)
protocolAddr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
@@ -775,6 +759,70 @@ func TestMLDSkipProtocol(t *testing.T) {
}
}
func TestSetMLDVersion(t *testing.T) {
const nicID = 1
c := newMLDTestContext()
s := c.s
e := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
defer e.Close()
defer c.cleanup()
ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
mldEP, ok := ep.(ipv6.MLDEndpoint)
if !ok {
t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP)
}
protocolAddr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: linkLocalAddr.WithPrefix(),
}
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p := e.Read(); p.IsNil() {
t.Fatal("expected a report message to be sent")
} else {
validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.MLDv2ReportRecordChangeToExcludeMode)
p.DecRef()
}
if got := mldEP.SetMLDVersion(ipv6.MLDVersion1); got != ipv6.MLDVersion2 {
t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion1, got, ipv6.MLDVersion2)
}
if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
}
if p := e.Read(); p.IsNil() {
t.Fatal("expected a report message to be sent")
} else {
validateMLDPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
p.DecRef()
}
if got := mldEP.SetMLDVersion(ipv6.MLDVersion2); got != ipv6.MLDVersion1 {
t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion2, got, ipv6.MLDVersion1)
}
if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
}
if p := e.Read(); p.IsNil() {
t.Fatal("expected a report message to be sent")
} else {
validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.MLDv2ReportRecordChangeToIncludeMode)
p.DecRef()
}
}
func TestMain(m *testing.M) {
refs.SetLeakMode(refs.LeaksPanic)
code := m.Run()