diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index 798a75f48..fb2e642ef 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -244,7 +244,7 @@ type protocolMode int const ( protocolModeV2 protocolMode = iota - protocolModeV1Forced + protocolModeV1 protocolModeV1Compatibility ) @@ -292,30 +292,35 @@ type GenericMulticastProtocolState struct { stateChangedReportV2TimerSet bool } -// SetForcedV1ModeLocked sets the V1 forced configuration. +// SetV1ModeLocked sets the V1 configuration. +// +// Returns the previous configuration. // // Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) SetForcedV1ModeLocked(v bool) { +func (g *GenericMulticastProtocolState) SetV1ModeLocked(v bool) bool { if v { switch g.mode { case protocolModeV2: g.cancelV2ReportTimers() case protocolModeV1Compatibility: g.modeTimer.Stop() - case protocolModeV1Forced: - // Already in V1 forced mode; nothing to do. + case protocolModeV1: + // Already in V1 mode; nothing to do. + return true default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) } - g.mode = protocolModeV1Forced - return + g.mode = protocolModeV1 + return false } switch g.mode { case protocolModeV2, protocolModeV1Compatibility: - // Not in V1 forced mode; nothing to do. - case protocolModeV1Forced: + // Not in V1 mode; nothing to do. + return false + case protocolModeV1: g.mode = protocolModeV2 + return true default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) } @@ -387,7 +392,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { groupAddress, ) } - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: handler = g.transitionToNonMemberLocked default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) @@ -434,7 +439,7 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { switch g.mode { case protocolModeV2: v2ReportBuilder = g.opts.Protocol.NewReportV2Builder() - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) } @@ -481,7 +486,7 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { switch g.mode { case protocolModeV2: g.sendV2ReportAndMaybeScheduleChangedTimer(groupAddress, &info, MulticastGroupProtocolV2ReportRecordChangeToExcludeMode) - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: g.maybeSendReportLocked(groupAddress, &info) default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) @@ -528,7 +533,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre // Nothing meaningful we can do with the error here - we only try to // send a delayed report once. _, _ = reportBuilder.Send() - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: g.maybeSendReportLocked(groupAddress, &info) default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) @@ -674,7 +679,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr } else { delete(g.memberships, groupAddress) } - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: g.transitionToNonMemberLocked(groupAddress, &info) delete(g.memberships, groupAddress) default: @@ -693,7 +698,7 @@ func (g *GenericMulticastProtocolState) HandleQueryV2Locked(groupAddress tcpip.A } switch g.mode { - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: g.handleQueryInnerLocked(groupAddress, g.opts.Protocol.V2QueryMaxRespCodeToV1Delay(maxResponseCode)) return case protocolModeV2: @@ -914,7 +919,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add } g.mode = protocolModeV1Compatibility g.cancelV2ReportTimers() - case protocolModeV1Forced: + case protocolModeV1: default: panic(fmt.Sprintf("unrecognized mode = %d", g.mode)) } @@ -996,7 +1001,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t callersV2ReportBuilder.AddRecord(MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, groupAddress) info.transmissionLeft-- } - case protocolModeV1Compatibility, protocolModeV1Forced: + case protocolModeV1Compatibility, protocolModeV1: info.transmissionLeft = unsolicitedTransmissionCount g.maybeSendReportLocked(groupAddress, info) default: diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index ca45db31f..e145ab7ac 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -60,7 +60,7 @@ func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOption m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) if v1Compatibility { - m.mu.genericMulticastGroup.SetForcedV1ModeLocked(true) + m.mu.genericMulticastGroup.SetV1ModeLocked(true) } } @@ -82,10 +82,10 @@ func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { m.mu.makeQueuePackets = v } -func (m *mockMulticastGroupProtocol) setForcedV1Mode(v bool) { +func (m *mockMulticastGroupProtocol) setV1Mode(v bool) bool { m.mu.Lock() defer m.mu.Unlock() - m.mu.genericMulticastGroup.SetForcedV1ModeLocked(v) + return m.mu.genericMulticastGroup.SetV1ModeLocked(v) } func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { @@ -1506,7 +1506,7 @@ func TestQueuedPackets(t *testing.T) { } } -func TestV1Compatibility(t *testing.T) { +func TestSetV1Mode(t *testing.T) { clock := faketime.NewManualClock() mgp := mockMulticastGroupProtocol{t: t} mgp.init(ip.GenericMulticastProtocolOptions{ @@ -1525,13 +1525,17 @@ func TestV1Compatibility(t *testing.T) { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.setForcedV1Mode(true) + if mgp.setV1Mode(true) { + t.Error("got mgp.setV1Mode(true) = true, want = false") + } mgp.joinGroup(addr2) if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.setForcedV1Mode(false) + if !mgp.setV1Mode(false) { + t.Error("got mgp.setV1Mode(false) = false, want = true") + } mgp.joinGroup(addr3) if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 31f48b37f..c3b1a88b5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -180,6 +180,7 @@ var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) var _ stack.NDPEndpoint = (*endpoint)(nil) +var _ MLDEndpoint = (*endpoint)(nil) var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { @@ -356,6 +357,13 @@ func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr}) } +// SetMLDVersion implements MLDEndpoint. +func (e *endpoint) SetMLDVersion(v MLDVersion) MLDVersion { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.mld.setVersion(v) +} + // SetNDPConfigurations implements NDPEndpoint. func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) { c.validate() diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 21bf98bfe..4444bb424 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -33,6 +33,26 @@ const ( UnsolicitedReportIntervalMax = 10 * time.Second ) +// MLDVersion is the forced version of MLD. +type MLDVersion int + +const ( + _ MLDVersion = iota + // MLDVersion1 indicates MLDv1. + MLDVersion1 + // MLDVersion2 indicates MLDv2. Note that MLD may still fallback to V1 + // compatibility mode as required by MLDv2. + MLDVersion2 +) + +// MLDEndpoint is a network endpoint that supports MLD. +type MLDEndpoint interface { + // Sets the MLD version. + // + // Returns the previous MLD version. + SetMLDVersion(MLDVersion) MLDVersion +} + // MLDOptions holds options for MLD. type MLDOptions struct { // Enabled indicates whether MLD will be performed. @@ -300,6 +320,26 @@ func (mld *mldState) sendQueuedReports() { mld.genericMulticastProtocol.SendQueuedReportsLocked() } +// setVersion sets the MLD version. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) setVersion(v MLDVersion) MLDVersion { + var prev bool + switch v { + case MLDVersion2: + prev = mld.genericMulticastProtocol.SetV1ModeLocked(false) + case MLDVersion1: + prev = mld.genericMulticastProtocol.SetV1ModeLocked(true) + default: + panic(fmt.Sprintf("unrecognized version = %d", v)) + } + + if prev { + return MLDVersion1 + } + return MLDVersion2 +} + // writePacket assembles and sends an MLD packet. // // Precondition: mld.ep.mu must be read locked. diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 4882c6396..0aeab2f21 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -44,6 +44,24 @@ var ( globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) ) +func checkVersion(t *testing.T, s *stack.Stack, nicID tcpip.NICID, v1 bool) { + if !v1 { + return + } + + ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + + mldEP, ok := ep.(ipv6.MLDEndpoint) + if !ok { + t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP) + } + + mldEP.SetMLDVersion(ipv6.MLDVersion1) +} + func validateMLDPacket(t *testing.T, v *bufferv2.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { t.Helper() @@ -141,17 +159,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if test.v1Compatibility { - createAndInjectMLDPacket( - e, - header.ICMPv6MulticastListenerQuery, - header.MLDHopLimit, - linkLocalAddr, - header.IPv6Any, - true, /* withRouterAlertOption */ - header.IPv6RouterAlertMLD, - ) - } + checkVersion(t, s, nicID, test.v1Compatibility) // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the @@ -306,20 +314,7 @@ func TestSendQueuedMLDReports(t *testing.T) { } } - checkVersion := func() { - if subTest.v1Compatibility { - createAndInjectMLDPacket( - e, - header.ICMPv6MulticastListenerQuery, - header.MLDHopLimit, - linkLocalAddr, - unusedMulticastAddr, - true, /* withRouterAlertOption */ - header.IPv6RouterAlertMLD, - ) - } - } - checkVersion() + checkVersion(t, s, nicID, subTest.v1Compatibility) var reportCounter uint64 var doneCounter uint64 @@ -335,7 +330,6 @@ func TestSendQueuedMLDReports(t *testing.T) { subTest.checkStats(t, s, reportCounter, doneCounter, reportV2Counter) subTest.validate(t, e, header.IPv6Any, []tcpip.Address{globalMulticastAddr}, false /* leave */) clock.Advance(time.Hour) - checkVersion() if p := e.Read(); !p.IsNil() { t.Errorf("got unexpected packet = %#v", p) p.DecRef() @@ -720,17 +714,7 @@ func TestMLDSkipProtocol(t *testing.T) { defer e.Close() defer c.cleanup() - if subTest.v1Compatibility { - createAndInjectMLDPacket( - e, - header.ICMPv6MulticastListenerQuery, - header.MLDHopLimit, - linkLocalAddr, - header.IPv6Any, - true, /* withRouterAlertOption */ - header.IPv6RouterAlertMLD, - ) - } + checkVersion(t, s, nicID, subTest.v1Compatibility) protocolAddr := tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, @@ -775,6 +759,70 @@ func TestMLDSkipProtocol(t *testing.T) { } } +func TestSetMLDVersion(t *testing.T) { + const nicID = 1 + + c := newMLDTestContext() + s := c.s + + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + defer e.Close() + defer c.cleanup() + + ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + mldEP, ok := ep.(ipv6.MLDEndpoint) + if !ok { + t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, mldEP) + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: linkLocalAddr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) + } + if p := e.Read(); p.IsNil() { + t.Fatal("expected a report message to be sent") + } else { + validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.MLDv2ReportRecordChangeToExcludeMode) + p.DecRef() + } + + if got := mldEP.SetMLDVersion(ipv6.MLDVersion1); got != ipv6.MLDVersion2 { + t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion1, got, ipv6.MLDVersion2) + } + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + if p := e.Read(); p.IsNil() { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + p.DecRef() + } + + if got := mldEP.SetMLDVersion(ipv6.MLDVersion2); got != ipv6.MLDVersion1 { + t.Errorf("got mldEP.SetMLDVersion(%d) = %d, want = %d", ipv6.MLDVersion2, got, ipv6.MLDVersion1) + } + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + if p := e.Read(); p.IsNil() { + t.Fatal("expected a report message to be sent") + } else { + validateMLDv2ReportPacket(t, stack.PayloadSince(p.NetworkHeader()), linkLocalAddr, globalMulticastAddr, header.MLDv2ReportRecordChangeToIncludeMode) + p.DecRef() + } +} + func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run()