diff --git a/agent.go b/agent.go index 93e5a28..f8a53bf 100644 --- a/agent.go +++ b/agent.go @@ -753,11 +753,10 @@ func (a *Agent) validateSelectedPair() bool { return false } - if (a.connectionTimeout != 0) && - (time.Since(selectedPair.remote.LastReceived()) > a.connectionTimeout) { - a.setSelectedPair(nil) + if (a.connectionTimeout != 0) && (time.Since(selectedPair.remote.LastReceived()) > a.connectionTimeout) { a.updateConnectionState(ConnectionStateDisconnected) - return false + } else { + a.updateConnectionState(ConnectionStateConnected) } return true diff --git a/connectivity_vnet_test.go b/connectivity_vnet_test.go index 12632bc..b03c3d8 100644 --- a/connectivity_vnet_test.go +++ b/connectivity_vnet_test.go @@ -6,6 +6,7 @@ import ( "context" "net" "sync" + "sync/atomic" "testing" "time" @@ -471,3 +472,89 @@ func TestConnectivityVNet(t *testing.T) { } }) } + +// TestDisconnectedToConnected asserts that an agent can go to disconnected, and then return to connected successfully +func TestDisconnectedToConnected(t *testing.T) { + loggerFactory := logging.NewDefaultLoggerFactory() + + // Create a network with two interfaces + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err) + + var dropAllData uint64 + wan.AddChunkFilter(func(vnet.Chunk) bool { + return atomic.LoadUint64(&dropAllData) != 1 + }) + + net0 := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"192.168.0.1"}, + }) + assert.NoError(t, wan.AddNet(net0)) + + net1 := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"192.168.0.2"}, + }) + assert.NoError(t, wan.AddNet(net1)) + + assert.NoError(t, wan.Start()) + + oneSecond := time.Second + threeSeconds := time.Second * 3 + + // Create two agents and connect them + controllingAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes, + MulticastDNSMode: MulticastDNSModeDisabled, + Net: net0, + ConnectionTimeout: &threeSeconds, + KeepaliveInterval: &oneSecond, + }) + assert.NoError(t, err) + + controlledAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes, + MulticastDNSMode: MulticastDNSModeDisabled, + Net: net1, + ConnectionTimeout: &threeSeconds, + KeepaliveInterval: &oneSecond, + }) + assert.NoError(t, err) + + controllingStateChanges := make(chan ConnectionState, 100) + assert.NoError(t, controllingAgent.OnConnectionStateChange(func(c ConnectionState) { + controllingStateChanges <- c + })) + + controlledStateChanges := make(chan ConnectionState, 100) + assert.NoError(t, controlledAgent.OnConnectionStateChange(func(c ConnectionState) { + controlledStateChanges <- c + })) + + connectWithVNet(controllingAgent, controlledAgent) + blockUntilStateSeen := func(expectedState ConnectionState, stateQueue chan ConnectionState) { + for s := range stateQueue { + if s == expectedState { + return + } + } + } + + // Assert we have gone to connected + blockUntilStateSeen(ConnectionStateConnected, controllingStateChanges) + blockUntilStateSeen(ConnectionStateConnected, controlledStateChanges) + + // Drop all packets, and block until we have gone to disconnected + atomic.StoreUint64(&dropAllData, 1) + blockUntilStateSeen(ConnectionStateDisconnected, controllingStateChanges) + blockUntilStateSeen(ConnectionStateDisconnected, controlledStateChanges) + + // Allow all packets through again, block until we have gone to connected + atomic.StoreUint64(&dropAllData, 0) + blockUntilStateSeen(ConnectionStateConnected, controllingStateChanges) + blockUntilStateSeen(ConnectionStateConnected, controlledStateChanges) + + assert.NoError(t, wan.Stop()) +}