diff options
author | Lorenzo Colitti <lorenzo@google.com> | 2019-05-09 11:57:25 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2019-05-09 11:57:25 +0000 |
commit | 839bb2cbf378ae42790a99104571f94c9cbfe59c (patch) | |
tree | 8f0992c38d340155f9a9a2e64609607f11c326d8 | |
parent | 5562cf3d9c24c7979018f918ac6598122297b3af (diff) | |
parent | cfc1abc714db769d3844623d927ba96160f50c89 (diff) |
Merge "Add tests for strict mode private DNS validation."
-rw-r--r-- | tests/src/com/android/server/connectivity/NetworkMonitorTest.java | 201 |
1 files changed, 165 insertions, 36 deletions
diff --git a/tests/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/src/com/android/server/connectivity/NetworkMonitorTest.java index 0dc1cbf..a24ed5a 100644 --- a/tests/src/com/android/server/connectivity/NetworkMonitorTest.java +++ b/tests/src/com/android/server/connectivity/NetworkMonitorTest.java @@ -42,10 +42,10 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -66,6 +66,7 @@ import android.net.NetworkCapabilities; import android.net.NetworkInfo; import android.net.captiveportal.CaptivePortalProbeResult; import android.net.metrics.IpConnectivityLog; +import android.net.shared.PrivateDnsConfig; import android.net.util.SharedLog; import android.net.wifi.WifiInfo; import android.net.wifi.WifiManager; @@ -73,6 +74,7 @@ import android.os.Bundle; import android.os.ConditionVariable; import android.os.Handler; import android.os.Looper; +import android.os.Process; import android.os.RemoteException; import android.os.SystemClock; import android.provider.Settings; @@ -132,6 +134,7 @@ public class NetworkMonitorTest { private @Mock NetworkMonitor.Dependencies mDependencies; private @Mock INetworkMonitorCallbacks mCallbacks; private @Spy Network mNetwork = new Network(TEST_NETID); + private @Mock Network mNonPrivateDnsBypassNetwork; private @Mock DataStallStatsUtils mDataStallStatsUtils; private @Mock WifiInfo mWifiInfo; private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor; @@ -166,31 +169,93 @@ public class NetworkMonitorTest { private static final NetworkCapabilities NO_INTERNET_CAPABILITIES = new NetworkCapabilities() .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR); - private void setDnsAnswers(String[] answers) throws UnknownHostException { - if (answers == null) { - doThrow(new UnknownHostException()).when(mNetwork).getAllByName(any()); - doNothing().when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any()); - return; + /** + * Fakes DNS responses. + * + * Allows test methods to configure the IP addresses that will be resolved by + * Network#getAllByName and by DnsResolver#query. + */ + class FakeDns { + private final ArrayMap<String, List<InetAddress>> mAnswers = new ArrayMap<>(); + private boolean mNonBypassPrivateDnsWorking = true; + + /** Whether DNS queries on mNonBypassPrivateDnsWorking should succeed. */ + private void setNonBypassPrivateDnsWorking(boolean working) { + mNonBypassPrivateDnsWorking = working; } - List<InetAddress> answerList = new ArrayList<>(); - for (String answer : answers) { - answerList.add(InetAddresses.parseNumericAddress(answer)); + /** Clears all DNS entries. */ + private synchronized void clearAll() { + mAnswers.clear(); } - InetAddress[] answerArray = answerList.toArray(new InetAddress[0]); - doReturn(answerArray).when(mNetwork).getAllByName(any()); + /** Returns the answer for a given name on the given mock network. */ + private synchronized List<InetAddress> getAnswer(Object mock, String hostname) { + if (mock == mNonPrivateDnsBypassNetwork && !mNonBypassPrivateDnsWorking) { + return null; + } + if (mAnswers.containsKey(hostname)) { + return mAnswers.get(hostname); + } + return mAnswers.get("*"); + } - doAnswer((invocation) -> { - Executor executor = (Executor) invocation.getArgument(3); - DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5); - new Handler(Looper.getMainLooper()).post(() -> { - executor.execute(() -> callback.onAnswer(answerList, 0)); - }); - return null; - }).when(mDnsResolver).query(eq(mNetwork), any(), anyInt(), any(), any(), any()); + /** Sets the answer for a given name. */ + private synchronized void setAnswer(String hostname, String[] answer) + throws UnknownHostException { + if (answer == null) { + mAnswers.remove(hostname); + } else { + List<InetAddress> answerList = new ArrayList<>(); + for (String addr : answer) { + answerList.add(InetAddresses.parseNumericAddress(addr)); + } + mAnswers.put(hostname, answerList); + } + } + + /** Simulates a getAllByName call for the specified name on the specified mock network. */ + private InetAddress[] getAllByName(Object mock, String hostname) + throws UnknownHostException { + List<InetAddress> answer = getAnswer(mock, hostname); + if (answer == null || answer.size() == 0) { + throw new UnknownHostException(hostname); + } + return answer.toArray(new InetAddress[0]); + } + + /** Starts mocking DNS queries. */ + private void startMocking() throws UnknownHostException { + // Queries on mNetwork (i.e., bypassing private DNS) using getAllByName. + doAnswer(invocation -> { + return getAllByName(invocation.getMock(), invocation.getArgument(0)); + }).when(mNetwork).getAllByName(any()); + + // Queries on mNonBypassPrivateDnsNetwork using getAllByName. + doAnswer(invocation -> { + return getAllByName(invocation.getMock(), invocation.getArgument(0)); + }).when(mNonPrivateDnsBypassNetwork).getAllByName(any()); + + // Queries on mNetwork (i.e., bypassing private DNS) using DnsResolver#query. + doAnswer(invocation -> { + String hostname = (String) invocation.getArgument(1); + Executor executor = (Executor) invocation.getArgument(3); + DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5); + + List<InetAddress> answer = getAnswer(invocation.getMock(), hostname); + if (answer != null && answer.size() > 0) { + new Handler(Looper.getMainLooper()).post(() -> { + executor.execute(() -> callback.onAnswer(answer, 0)); + }); + } + // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE. + return null; + }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any()); + } } + private FakeDns mFakeDns; + @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); @@ -206,7 +271,7 @@ public class NetworkMonitorTest { when(mDependencies.getSetting(any(), eq(Settings.Global.CAPTIVE_PORTAL_HTTPS_URL), any())) .thenReturn(TEST_HTTPS_URL); - doReturn(mNetwork).when(mNetwork).getPrivateDnsBypassingCopy(); + doReturn(mNetwork).when(mNonPrivateDnsBypassNetwork).getPrivateDnsBypassingCopy(); when(mContext.getSystemService(Context.CONNECTIVITY_SERVICE)).thenReturn(mCm); when(mContext.getSystemService(Context.TELEPHONY_SERVICE)).thenReturn(mTelephony); @@ -222,6 +287,9 @@ public class NetworkMonitorTest { setFallbackSpecs(null); // Test with no fallback spec by default when(mRandom.nextInt()).thenReturn(0); + when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout))) + .thenReturn(500); + doAnswer((invocation) -> { URL url = invocation.getArgument(0); switch(url.toString()) { @@ -241,7 +309,9 @@ public class NetworkMonitorTest { when(mHttpConnection.getRequestProperties()).thenReturn(new ArrayMap<>()); when(mHttpsConnection.getRequestProperties()).thenReturn(new ArrayMap<>()); - setDnsAnswers(new String[]{"2001:db8::1", "192.0.2.2"}); + mFakeDns = new FakeDns(); + mFakeDns.startMocking(); + mFakeDns.setAnswer("*", new String[]{"2001:db8::1", "192.0.2.2"}); when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> { mRegisteredReceivers.add(invocation.getArgument(0)); @@ -264,6 +334,7 @@ public class NetworkMonitorTest { @After public void tearDown() { + mFakeDns.clearAll(); assertTrue(mCreatedNetworkMonitors.size() > 0); // Make a local copy of mCreatedNetworkMonitors because during the iteration below, // WrappedNetworkMonitor#onQuitting will delete elements from it on the handler threads. @@ -284,8 +355,8 @@ public class NetworkMonitorTest { private final ConditionVariable mQuitCv = new ConditionVariable(false); WrappedNetworkMonitor() { - super(mContext, mCallbacks, mNetwork, mLogger, mValidationLogger, mDependencies, - mDataStallStatsUtils); + super(mContext, mCallbacks, mNonPrivateDnsBypassNetwork, mLogger, mValidationLogger, + mDependencies, mDataStallStatsUtils); } @Override @@ -314,23 +385,22 @@ public class NetworkMonitorTest { } } - private WrappedNetworkMonitor makeMonitor() { + private WrappedNetworkMonitor makeMonitor(NetworkCapabilities nc) { final WrappedNetworkMonitor nm = new WrappedNetworkMonitor(); nm.start(); + setNetworkCapabilities(nm, nc); waitForIdle(nm.getHandler()); mCreatedNetworkMonitors.add(nm); return nm; } private WrappedNetworkMonitor makeMeteredNetworkMonitor() { - final WrappedNetworkMonitor nm = makeMonitor(); - setNetworkCapabilities(nm, METERED_CAPABILITIES); + final WrappedNetworkMonitor nm = makeMonitor(METERED_CAPABILITIES); return nm; } private WrappedNetworkMonitor makeNotMeteredNetworkMonitor() { - final WrappedNetworkMonitor nm = makeMonitor(); - setNetworkCapabilities(nm, NOT_METERED_CAPABILITIES); + final WrappedNetworkMonitor nm = makeMonitor(NOT_METERED_CAPABILITIES); return nm; } @@ -603,7 +673,7 @@ public class NetworkMonitorTest { setSslException(mHttpsConnection); setPortal302(mHttpConnection); - final NetworkMonitor nm = makeMonitor(); + final NetworkMonitor nm = makeMonitor(METERED_CAPABILITIES); nm.notifyNetworkConnected(TEST_LINK_PROPERTIES, METERED_CAPABILITIES); verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) @@ -638,6 +708,63 @@ public class NetworkMonitorTest { } @Test + public void testPrivateDnsSuccess() throws Exception { + setStatus(mHttpsConnection, 204); + setStatus(mHttpConnection, 204); + mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::53"}); + + WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor(); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); + wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null)); + } + + @Test + public void testPrivateDnsResolutionRetryUpdate() throws Exception { + // Set a private DNS hostname that doesn't resolve and expect validation to fail. + mFakeDns.setAnswer("dns.google", new String[0]); + setStatus(mHttpsConnection, 204); + setStatus(mHttpConnection, 204); + + WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor(); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); + wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null)); + + // Fix DNS and retry, expect validation to succeed. + reset(mCallbacks); + mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"}); + + wnm.forceReevaluation(Process.myUid()); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null)); + + // Change configuration to an invalid DNS name, expect validation to fail. + reset(mCallbacks); + mFakeDns.setAnswer("dns.bad", new String[0]); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.bad", new InetAddress[0])); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null)); + + // Change configuration back to working again, but make private DNS not work. + // Expect validation to fail. + reset(mCallbacks); + mFakeDns.setNonBypassPrivateDnsWorking(false); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null)); + + // Make private DNS work again. Expect validation to succeed. + reset(mCallbacks); + mFakeDns.setNonBypassPrivateDnsWorking(true); + wnm.forceReevaluation(Process.myUid()); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null)); + } + + @Test public void testDataStall_StallSuspectedAndSendMetrics() throws IOException { WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor(); wrappedMonitor.setLastProbeTime(SystemClock.elapsedRealtime() - 1000); @@ -728,25 +855,27 @@ public class NetworkMonitorTest { WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor(); final int shortTimeoutMs = 200; + // Clear the wildcard DNS response created in setUp. + mFakeDns.setAnswer("*", null); + String[] expected = new String[]{"2001:db8::"}; - setDnsAnswers(expected); + mFakeDns.setAnswer("www.google.com", expected); InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs); assertIpAddressArrayEquals(expected, actual); expected = new String[]{"2001:db8::", "192.0.2.1"}; - setDnsAnswers(expected); - actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs); + mFakeDns.setAnswer("www.googleapis.com", expected); + actual = wnm.sendDnsProbeWithTimeout("www.googleapis.com", shortTimeoutMs); assertIpAddressArrayEquals(expected, actual); - expected = new String[0]; - setDnsAnswers(expected); + mFakeDns.setAnswer("www.google.com", new String[0]); try { wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs); fail("No DNS results, expected UnknownHostException"); } catch (UnknownHostException e) { } - setDnsAnswers(null); + mFakeDns.setAnswer("www.google.com", null); try { wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs); fail("DNS query timed out, expected UnknownHostException"); @@ -841,7 +970,7 @@ public class NetworkMonitorTest { } private NetworkMonitor runNetworkTest(NetworkCapabilities nc, int testResult) { - final NetworkMonitor monitor = makeMonitor(); + final NetworkMonitor monitor = makeMonitor(nc); monitor.notifyNetworkConnected(TEST_LINK_PROPERTIES, nc); try { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) |