summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--services/core/java/com/android/server/vcn/Vcn.java14
-rw-r--r--tests/vcn/java/com/android/server/vcn/VcnTest.java69
2 files changed, 64 insertions, 19 deletions
diff --git a/services/core/java/com/android/server/vcn/Vcn.java b/services/core/java/com/android/server/vcn/Vcn.java
index 9d39c67d27fb..c55913e2e547 100644
--- a/services/core/java/com/android/server/vcn/Vcn.java
+++ b/services/core/java/com/android/server/vcn/Vcn.java
@@ -128,7 +128,6 @@ public class Vcn extends Handler {
* from VcnManagementService, and therefore cannot rely on guarantees of running on the VCN
* Looper.
*/
- // TODO(b/179429339): update when exiting safemode (when a new VcnConfig is provided)
private final AtomicBoolean mIsActive = new AtomicBoolean(true);
public Vcn(
@@ -203,7 +202,8 @@ public class Vcn extends Handler {
@Override
public void handleMessage(@NonNull Message msg) {
- if (!isActive()) {
+ // Ignore if this Vcn is not active and we're not receiving new configs
+ if (!isActive() && msg.what != MSG_EVENT_CONFIG_UPDATED) {
return;
}
@@ -237,7 +237,13 @@ public class Vcn extends Handler {
mConfig = config;
- // TODO: Reevaluate active VcnGatewayConnection(s)
+ // TODO(b/181815405): Reevaluate active VcnGatewayConnection(s)
+
+ if (!mIsActive.getAndSet(true)) {
+ // If this VCN was not previously active, it is exiting Safe Mode. Re-register the
+ // request listener to get NetworkRequests again (and all cached requests).
+ mVcnContext.getVcnNetworkProvider().registerListener(mRequestListener);
+ }
}
private void handleTeardown() {
@@ -253,6 +259,8 @@ public class Vcn extends Handler {
private void handleEnterSafeMode() {
handleTeardown();
+ mVcnGatewayConnections.clear();
+
mVcnCallback.onEnteredSafeMode();
}
diff --git a/tests/vcn/java/com/android/server/vcn/VcnTest.java b/tests/vcn/java/com/android/server/vcn/VcnTest.java
index 3dd710afed7b..4fa63d4ff640 100644
--- a/tests/vcn/java/com/android/server/vcn/VcnTest.java
+++ b/tests/vcn/java/com/android/server/vcn/VcnTest.java
@@ -22,7 +22,9 @@ import static android.net.NetworkCapabilities.NET_CAPABILITY_MMS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
@@ -48,6 +50,7 @@ import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
+import java.util.Arrays;
import java.util.Set;
import java.util.UUID;
@@ -58,7 +61,7 @@ public class VcnTest {
private static final int PROVIDER_ID = 5;
private static final int[][] TEST_CAPS =
new int[][] {
- new int[] {NET_CAPABILITY_INTERNET, NET_CAPABILITY_MMS},
+ new int[] {NET_CAPABILITY_MMS, NET_CAPABILITY_INTERNET},
new int[] {NET_CAPABILITY_DUN}
};
@@ -155,14 +158,6 @@ public class VcnTest {
}
}
- @Test
- public void testGatewayEnteringSafeModeNotifiesVcn() {
- final NetworkRequestListener requestListener = verifyAndGetRequestListener();
- for (final int capability : VcnGatewayConnectionConfigTest.EXPOSED_CAPS) {
- startVcnGatewayWithCapabilities(requestListener, capability);
- }
- }
-
private void triggerVcnRequestListeners(NetworkRequestListener requestListener) {
for (final int[] caps : TEST_CAPS) {
startVcnGatewayWithCapabilities(requestListener, caps);
@@ -188,8 +183,20 @@ public class VcnTest {
return gatewayConnections;
}
+ private void verifySafeMode(
+ NetworkRequestListener requestListener,
+ Set<VcnGatewayConnection> expectedGatewaysTornDown) {
+ assertFalse(mVcn.isActive());
+ assertTrue(mVcn.getVcnGatewayConnections().isEmpty());
+ for (final VcnGatewayConnection gatewayConnection : expectedGatewaysTornDown) {
+ verify(gatewayConnection).teardownAsynchronously();
+ }
+ verify(mVcnNetworkProvider).unregisterListener(requestListener);
+ verify(mVcnCallback).onEnteredSafeMode();
+ }
+
@Test
- public void testGatewayEnteringSafemodeNotifiesVcn() {
+ public void testGatewayEnteringSafeModeNotifiesVcn() {
final NetworkRequestListener requestListener = verifyAndGetRequestListener();
final Set<VcnGatewayConnection> gatewayConnections =
startGatewaysAndGetGatewayConnections(requestListener);
@@ -200,12 +207,7 @@ public class VcnTest {
statusCallback.onEnteredSafeMode();
mTestLooper.dispatchAll();
- assertFalse(mVcn.isActive());
- for (final VcnGatewayConnection gatewayConnection : gatewayConnections) {
- verify(gatewayConnection).teardownAsynchronously();
- }
- verify(mVcnNetworkProvider).unregisterListener(requestListener);
- verify(mVcnCallback).onEnteredSafeMode();
+ verifySafeMode(requestListener, gatewayConnections);
}
@Test
@@ -234,4 +236,39 @@ public class VcnTest {
any(),
mGatewayStatusCallbackCaptor.capture());
}
+
+ @Test
+ public void testUpdateConfigExitsSafeMode() {
+ final NetworkRequestListener requestListener = verifyAndGetRequestListener();
+ final Set<VcnGatewayConnection> gatewayConnections =
+ new ArraySet<>(startGatewaysAndGetGatewayConnections(requestListener));
+
+ final VcnGatewayStatusCallback statusCallback = mGatewayStatusCallbackCaptor.getValue();
+ statusCallback.onEnteredSafeMode();
+ mTestLooper.dispatchAll();
+ verifySafeMode(requestListener, gatewayConnections);
+
+ doAnswer(invocation -> {
+ final NetworkRequestListener listener = invocation.getArgument(0);
+ triggerVcnRequestListeners(listener);
+ return null;
+ }).when(mVcnNetworkProvider).registerListener(eq(requestListener));
+
+ mVcn.updateConfig(mConfig);
+ mTestLooper.dispatchAll();
+
+ // Registered on start, then re-registered with new configs
+ verify(mVcnNetworkProvider, times(2)).registerListener(eq(requestListener));
+ assertTrue(mVcn.isActive());
+ for (final int[] caps : TEST_CAPS) {
+ // Expect each gateway connection created on initial startup, and again with new configs
+ verify(mDeps, times(2))
+ .newVcnGatewayConnection(
+ eq(mVcnContext),
+ eq(TEST_SUB_GROUP),
+ eq(mSubscriptionSnapshot),
+ argThat(config -> Arrays.equals(caps, config.getExposedCapabilities())),
+ any());
+ }
+ }
}