summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/com/android/networkstack/util/DnsUtils.java99
-rw-r--r--src/com/android/server/connectivity/NetworkMonitor.java27
-rw-r--r--tests/lib/src/com/android/testutils/FileUtils.kt11
-rw-r--r--tests/unit/src/android/net/testutils/TrackRecordTest.kt396
-rw-r--r--tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java50
5 files changed, 517 insertions, 66 deletions
diff --git a/src/com/android/networkstack/util/DnsUtils.java b/src/com/android/networkstack/util/DnsUtils.java
index 4767d55..e68976a 100644
--- a/src/com/android/networkstack/util/DnsUtils.java
+++ b/src/com/android/networkstack/util/DnsUtils.java
@@ -21,21 +21,25 @@ import static android.net.DnsResolver.TYPE_A;
import static android.net.DnsResolver.TYPE_AAAA;
import android.annotation.NonNull;
+import android.annotation.Nullable;
import android.net.DnsResolver;
import android.net.Network;
import android.net.TrafficStats;
+import android.net.util.Stopwatch;
import android.util.Log;
import com.android.internal.util.TrafficStatsConstants;
+import com.android.server.connectivity.NetworkMonitor.DnsLogFunc;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.TimeoutException;
/**
* Collection of utilities for dns query.
@@ -44,6 +48,7 @@ public class DnsUtils {
// Decide what queries to make depending on what IP addresses are on the system.
public static final int TYPE_ADDRCONFIG = -1;
private static final String TAG = DnsUtils.class.getSimpleName();
+ private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
/**
* Return both A and AAAA query results regardless the ip address type of the giving network.
@@ -51,27 +56,33 @@ public class DnsUtils {
*/
@NonNull
public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
- @NonNull final Network network, @NonNull String host, int timeout)
- throws UnknownHostException {
+ @NonNull final Network network, @NonNull String host, int timeout,
+ @NonNull final DnsLogFunc logger) throws UnknownHostException {
final List<InetAddress> result = new ArrayList<InetAddress>();
+ final StringBuilder errorMsg = new StringBuilder(host);
try {
result.addAll(Arrays.asList(
getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
- timeout)));
+ timeout, logger)));
} catch (UnknownHostException e) {
// Might happen if the host is v4-only, still need to query TYPE_A
+ errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_AAAA), e.getMessage()));
}
try {
result.addAll(Arrays.asList(
getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP,
- timeout)));
+ timeout, logger)));
} catch (UnknownHostException e) {
// Might happen if the host is v6-only, still need to return AAAA answers
+ errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_A), e.getMessage()));
}
+
if (result.size() == 0) {
- throw new UnknownHostException(host);
+ logger.log("FAIL: " + errorMsg.toString());
+ throw new UnknownHostException(errorMsg.toString());
}
+ logger.log("OK: " + host + " " + result.toString());
return result.toArray(new InetAddress[0]);
}
@@ -82,26 +93,34 @@ public class DnsUtils {
@NonNull
public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
@NonNull final Network network, @NonNull final String host, int type, int flag,
- int timeoutMs) throws UnknownHostException {
- final CountDownLatch latch = new CountDownLatch(1);
- final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
+ int timeoutMs, @Nullable final DnsLogFunc logger) throws UnknownHostException {
+ final CompletableFuture<List<InetAddress>> resultRef = new CompletableFuture<>();
+ final Stopwatch watch = new Stopwatch().start();
+
final DnsResolver.Callback<List<InetAddress>> callback =
- new DnsResolver.Callback<List<InetAddress>>() {
+ new DnsResolver.Callback<List<InetAddress>>() {
@Override
public void onAnswer(List<InetAddress> answer, int rcode) {
- if (rcode == 0) {
- resultRef.set(answer);
+ if (rcode == 0 && answer != null && answer.size() != 0) {
+ resultRef.complete(answer);
+ } else {
+ resultRef.completeExceptionally(new UnknownHostException());
}
- latch.countDown();
}
@Override
public void onError(@NonNull DnsResolver.DnsException e) {
- Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage());
- latch.countDown();
+ if (DBG) {
+ Log.d(TAG, "DNS error resolving " + host, e);
+ }
+ resultRef.completeExceptionally(e);
}
};
+ // TODO: Investigate whether this is still useful.
+ // The packets that actually do the DNS queries are sent by netd, but netd doesn't
+ // look at the tag at all. Given that this is a library, the tag should be passed in by the
+ // caller.
final int oldTag = TrafficStats.getAndSetThreadStatsTag(
TrafficStatsConstants.TAG_SYSTEM_PROBE);
@@ -115,16 +134,52 @@ public class DnsUtils {
TrafficStats.setThreadStatsTag(oldTag);
+ List<InetAddress> result = null;
+ Exception exception = null;
try {
- latch.await(timeoutMs, TimeUnit.MILLISECONDS);
- } catch (InterruptedException e) {
+ result = resultRef.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ exception = e;
+ } catch (TimeoutException | InterruptedException e) {
+ exception = new UnknownHostException("Timeout");
+ } finally {
+ logDnsResult(result, watch.stop() /* latency */, logger, type,
+ exception != null ? exception.getMessage() : "" /* errorMsg */);
}
- final List<InetAddress> result = resultRef.get();
- if (result == null || result.size() == 0) {
- throw new UnknownHostException(host);
- }
+ if (null != exception) throw (UnknownHostException) exception;
return result.toArray(new InetAddress[0]);
}
+
+ private static void logDnsResult(@Nullable final List<InetAddress> results, final long latency,
+ @Nullable final DnsLogFunc logger, int type, @NonNull final String errorMsg) {
+ if (logger == null) {
+ return;
+ }
+
+ if (results != null && results.size() != 0) {
+ final StringBuilder builder = new StringBuilder();
+ for (InetAddress address : results) {
+ builder.append(',').append(address.getHostAddress());
+ }
+ logger.log(String.format("%dms OK %s", latency, builder.substring(1)));
+ } else {
+ logger.log(String.format("%dms FAIL in type %s %s", latency, dnsTypeToStr(type),
+ errorMsg));
+ }
+ }
+
+ private static String dnsTypeToStr(int type) {
+ switch (type) {
+ case TYPE_A:
+ return "A";
+ case TYPE_AAAA:
+ return "AAAA";
+ case TYPE_ADDRCONFIG:
+ return "ADDRCONFIG";
+ default:
+ }
+ return "UNDEFINED";
+ }
}
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java
index 0b2c051..d4b484d 100644
--- a/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -1044,12 +1044,11 @@ public class NetworkMonitor extends StateMachine {
try {
// Do a blocking DNS resolution using the network-assigned nameservers.
final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(),
- mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout());
+ mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout(),
+ str -> validationLog("Strict mode hostname resolution " + str));
mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
- validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
} catch (UnknownHostException uhe) {
mPrivateDnsConfig = null;
- validationLog("Strict mode hostname resolution failed: " + uhe.getMessage());
}
mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
(mPrivateDnsConfig != null) /* succeeded */);
@@ -1153,7 +1152,6 @@ public class NetworkMonitor extends StateMachine {
} else if (probeResult.isPartialConnectivity()) {
mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_PARTIAL,
null /* redirectUrl */);
- // Check if disable https probing needed.
maybeDisableHttpsProbing(mAcceptPartialConnectivity);
if (mAcceptPartialConnectivity) {
transitionTo(mEvaluatingPrivateDnsState);
@@ -1557,7 +1555,8 @@ public class NetworkMonitor extends StateMachine {
protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
throws UnknownHostException {
return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host,
- TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs);
+ TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs,
+ str -> validationLog(ValidationProbeEvent.PROBE_DNS, host, str));
}
/** Do a DNS resolution of the given server. */
@@ -1572,19 +1571,11 @@ public class NetworkMonitor extends StateMachine {
String connectInfo;
try {
InetAddress[] addresses = sendDnsProbeWithTimeout(host, getDnsProbeTimeout());
- StringBuffer buffer = new StringBuffer();
- for (InetAddress address : addresses) {
- buffer.append(',').append(address.getHostAddress());
- }
result = ValidationProbeEvent.DNS_SUCCESS;
- connectInfo = "OK " + buffer.substring(1);
} catch (UnknownHostException e) {
result = ValidationProbeEvent.DNS_FAILURE;
- connectInfo = "FAIL";
}
final long latency = watch.stop();
- validationLog(ValidationProbeEvent.PROBE_DNS, host,
- String.format("%dms %s", latency, connectInfo));
logValidationProbe(latency, ValidationProbeEvent.PROBE_DNS, result);
}
@@ -2175,4 +2166,14 @@ public class NetworkMonitor extends StateMachine {
}
mEvaluationState.noteProbeResult(probeResult, succeeded);
}
+
+ /**
+ * Interface for logging dns results.
+ */
+ public interface DnsLogFunc {
+ /**
+ * Log function.
+ */
+ void log(String s);
+ }
}
diff --git a/tests/lib/src/com/android/testutils/FileUtils.kt b/tests/lib/src/com/android/testutils/FileUtils.kt
new file mode 100644
index 0000000..edd8d83
--- /dev/null
+++ b/tests/lib/src/com/android/testutils/FileUtils.kt
@@ -0,0 +1,11 @@
+package com.android.testutils
+
+// This function is private because the 2 is hardcoded here, and is not correct if not called
+// directly from __LINE__ or __FILE__.
+private fun callerStackTrace(): StackTraceElement = try {
+ throw RuntimeException()
+} catch (e: RuntimeException) {
+ e.stackTrace[2] // 0 is here, 1 is get() in __FILE__ or __LINE__
+}
+val __FILE__: String get() = callerStackTrace().fileName
+val __LINE__: Int get() = callerStackTrace().lineNumber
diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
index 5f3e81a..77c2cd3 100644
--- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt
+++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
@@ -16,10 +16,17 @@
package android.net.testutils
+import android.os.SystemClock
import com.android.testutils.ArrayTrackRecord
+import com.android.testutils.TrackRecord
+import com.android.testutils.__FILE__
+import com.android.testutils.__LINE__
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
+import java.util.concurrent.CyclicBarrier
+import java.util.concurrent.TimeUnit
+import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
@@ -30,6 +37,14 @@ import kotlin.test.fail
val TEST_VALUES = listOf(4, 13, 52, 94, 41, 68, 11, 13, 51, 0, 91, 94, 33, 98, 14)
const val ABSENT_VALUE = 2
+// Caution in changing these : some tests rely on the fact that TEST_TIMEOUT > 2 * SHORT_TIMEOUT
+// and LONG_TIMEOUT > 2 * TEST_TIMEOUT
+const val SHORT_TIMEOUT = 40L // ms
+const val TEST_TIMEOUT = 200L // ms
+const val LONG_TIMEOUT = 5000L // ms
+
+// The unit of time for interpreted tests
+const val INTERPRET_TIME_UNIT = SHORT_TIMEOUT
@RunWith(JUnit4::class)
class TrackRecordTest {
@@ -135,19 +150,390 @@ class TrackRecordTest {
TEST_VALUES.subList(3, TEST_VALUES.size - 3))
}
+ fun testPollReturnsImmediately(record: TrackRecord<Int>) {
+ record.add(4)
+ val elapsed = measureTimeMillis { assertEquals(4, record.poll(LONG_TIMEOUT, 0)) }
+ // Should not have waited at all, in fact.
+ assertTrue(elapsed < LONG_TIMEOUT)
+ record.add(7)
+ record.add(9)
+ // Can poll multiple times for the same position, in whatever order
+ assertEquals(9, record.poll(0, 2))
+ assertEquals(7, record.poll(Long.MAX_VALUE, 1))
+ assertEquals(9, record.poll(0, 2))
+ assertEquals(4, record.poll(0, 0))
+ assertEquals(9, record.poll(0, 2) { it > 5 })
+ assertEquals(7, record.poll(0, 0) { it > 5 })
+ }
+
+ @Test
+ fun testPollReturnsImmediately() {
+ testPollReturnsImmediately(ArrayTrackRecord())
+ testPollReturnsImmediately(ArrayTrackRecord<Int>().newReadHead())
+ }
+
+ @Test
+ fun testPollTimesOut() {
+ val record = ArrayTrackRecord<Int>()
+ var delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0)) }
+ assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT")
+ delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) }
+ assertTrue(delay > SHORT_TIMEOUT)
+ }
+
+ @Test
+ fun testPollWakesUp() {
+ val record = ArrayTrackRecord<Int>()
+ val barrier = CyclicBarrier(2)
+ Thread {
+ barrier.await(LONG_TIMEOUT, TimeUnit.MILLISECONDS) // barrier 1
+ barrier.await() // barrier 2
+ Thread.sleep(SHORT_TIMEOUT * 2)
+ record.add(31)
+ }.start()
+ barrier.await() // barrier 1
+ // Should find the element in more than SHORT_TIMEOUT but less than TEST_TIMEOUT
+ var delay = measureTimeMillis {
+ barrier.await() // barrier 2
+ assertEquals(31, record.poll(TEST_TIMEOUT, 0))
+ }
+ assertTrue(delay in SHORT_TIMEOUT..TEST_TIMEOUT)
+ // Polling for an element already added in anothe thread (pos 0) : should return immediately
+ delay = measureTimeMillis { assertEquals(31, record.poll(TEST_TIMEOUT, 0)) }
+ assertTrue(delay < TEST_TIMEOUT, "Delay $delay > $TEST_TIMEOUT")
+ // Waiting for an element that never comes
+ delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 1)) }
+ assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT")
+ // Polling for an element that doesn't match what is already there
+ delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) }
+ assertTrue(delay > SHORT_TIMEOUT)
+ }
+
+ // Just make sure the interpreter actually throws an exception when the spec
+ // does not conform to the behavior. The interpreter is just a tool to test a
+ // tool used for a tool for test, let's not have hundreds of tests for it ;
+ // if it's broken one of the tests using it will break.
+ @Test
+ fun testInterpreter() {
+ val interpretLine = __LINE__ + 2
+ try {
+ interpretTestSpec(useReadHeads = true, spec = """
+ add(4) | poll(1, 0) = 5
+ """)
+ fail("This spec should have thrown")
+ } catch (e: InterpretException) {
+ assertTrue(e.cause is AssertionError)
+ assertEquals(interpretLine + 1, e.stackTrace[0].lineNumber)
+ assertTrue(e.stackTrace[0].fileName.contains(__FILE__))
+ assertTrue(e.stackTrace[0].methodName.contains("testInterpreter"))
+ assertTrue(e.stackTrace[0].methodName.contains("thread1"))
+ }
+ }
+
+ @Test
+ fun testMultipleAdds() {
+ interpretTestSpec(useReadHeads = false, spec = """
+ add(2) | | |
+ | add(4) | |
+ | | add(6) |
+ | | | add(8)
+ poll(0, 0) = 2 time 0..1 | poll(0, 0) = 2 | poll(0, 0) = 2 | poll(0, 0) = 2
+ poll(0, 1) = 4 time 0..1 | poll(0, 1) = 4 | poll(0, 1) = 4 | poll(0, 1) = 4
+ poll(0, 2) = 6 time 0..1 | poll(0, 2) = 6 | poll(0, 2) = 6 | poll(0, 2) = 6
+ poll(0, 3) = 8 time 0..1 | poll(0, 3) = 8 | poll(0, 3) = 8 | poll(0, 3) = 8
+ """)
+ }
+
+ @Test
+ fun testConcurrentAdds() {
+ interpretTestSpec(useReadHeads = false, spec = """
+ add(2) | add(4) | add(6) | add(8)
+ add(1) | add(3) | add(5) | add(7)
+ poll(0, 1) is even | poll(0, 0) is even | poll(0, 3) is even | poll(0, 2) is even
+ poll(0, 5) is odd | poll(0, 4) is odd | poll(0, 7) is odd | poll(0, 6) is odd
+ """)
+ }
+
+ @Test
+ fun testMultiplePoll() {
+ interpretTestSpec(useReadHeads = false, spec = """
+ add(4) | poll(1, 0) = 4
+ | poll(0, 1) = null time 0..1
+ | poll(1, 1) = null time 1..2
+ sleep; add(7) | poll(2, 1) = 7 time 1..2
+ sleep; add(18) | poll(2, 2) = 18 time 1..2
+ """)
+ }
+
+ @Test
+ fun testMultiplePollWithPredicate() {
+ interpretTestSpec(useReadHeads = false, spec = """
+ | poll(1, 0) = null | poll(1, 0) = null
+ add(6) | poll(1, 0) = 6 |
+ add(11) | poll(1, 0) { > 20 } = null | poll(1, 0) { = 11 } = 11
+ | poll(1, 0) { > 8 } = 11 |
+ """)
+ }
+
+ @Test
+ fun testMultipleReadHeads() {
+ interpretTestSpec(useReadHeads = true, spec = """
+ | poll() = null | poll() = null | poll() = null
+ add(5) | | poll() = 5 |
+ | poll() = 5 | |
+ add(8) | poll() = 8 | poll() = 8 |
+ | | | poll() = 5
+ | | | poll() = 8
+ | | | poll() = null
+ | | poll() = null |
+ """)
+ }
+
+ @Test
+ fun testReadHeadPollWithPredicate() {
+ interpretTestSpec(useReadHeads = true, spec = """
+ add(5) | poll() { < 0 } = null
+ | poll() { > 5 } = null
+ add(10) |
+ | poll() { = 5 } = null // The "5" was skipped in the previous line
+ add(15) | poll() { > 8 } = 15 // The "10" was skipped in the previous line
+ | poll(1, 0) { > 8 } = 10 // 10 is the first element after pos 0 matching > 8
+ """)
+ }
+
+ @Test
+ fun testPollImmediatelyAdvancesReadhead() {
+ interpretTestSpec(useReadHeads = true, spec = """
+ add(1) | add(2) | add(3) | add(4)
+ mark = 0 | poll(0) { > 3 } = 4 | |
+ poll(0) { > 10 } = null | | |
+ mark = 4 | | |
+ poll() = null | | |
+ """)
+ }
+
+ @Test
+ fun testParallelReadHeads() {
+ interpretTestSpec(useReadHeads = true, spec = """
+ mark = 0 | mark = 0 | mark = 0 | mark = 0
+ add(2) | | |
+ | add(4) | |
+ | | add(6) |
+ | | | add(8)
+ poll() = 2 | poll() = 2 | poll() = 2 | poll() = 2
+ poll() = 4 | poll() = 4 | poll() = 4 | poll() = 4
+ poll() = 6 | poll() = 6 | poll() = 6 | mark = 2
+ poll() = 8 | poll() = 8 | mark = 3 | poll() = 6
+ mark = 4 | mark = 4 | poll() = 8 | poll() = 8
+ """)
+ }
+
+ @Test
+ fun testPeek() {
+ interpretTestSpec(useReadHeads = true, spec = """
+ add(2) | | |
+ | add(4) | |
+ | | add(6) |
+ | | | add(8)
+ peek() = 2 | poll() = 2 | poll() = 2 | peek() = 2
+ peek() = 2 | peek() = 4 | poll() = 4 | peek() = 2
+ peek() = 2 | peek() = 4 | peek() = 6 | poll() = 2
+ peek() = 2 | mark = 1 | mark = 2 | poll() = 4
+ mark = 0 | peek() = 4 | peek() = 6 | peek() = 6
+ poll() = 2 | poll() = 4 | poll() = 6 | poll() = 6
+ poll() = 4 | mark = 2 | poll() = 8 | peek() = 8
+ peek() = 6 | peek() = 6 | peek() = null | mark = 3
+ """)
+ }
/**
- * // TODO : add the following tests.
+ * // TODO : don't submit without this.
* Test poll()
- * - Put stuff, check that it's returned immediately
- * - Check that it waits and times out
- * - Check that it waits and finds the stuff added through the timeout
- * - Put stuff, check that it's returned immediately when it matches the predicate
* - Check that it immediately finds added stuff that matches
* Test ReadHead#poll()
* - All of the above, and:
* - Put stuff, check that it timeouts when it doesn't match the predicate, and the read head
* has advanced
* - Check that it immediately advances the read head
+ * - Check multiple read heads in different threads
* Test ReadHead#peek()
*/
}
+
+/**
+ * A small interpreter for testing parallel code. The interpreter will read a list of lines
+ * consisting of "|"-separated statements. Each column runs in a different concurrent thread
+ * and all threads wait for each other in between lines. Each statement is split on ";" then
+ * matched with regular expressions in the instructionTable constant, which contains the
+ * code associated with each statement.
+ *
+ * The time unit is defined in milliseconds by the INTERPRET_TIME_UNIT constant. Whitespace is
+ * ignored. Quick ref of supported expressions :
+ * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
+ * add(x) : calls and returns TrackRecord#add.
+ * poll(time, pos) [{ predicate }] : calls and returns TrackRecord#poll(x time units, pos).
+ * Optionally, a predicate may be specified.
+ * poll() [{ predicate }] : calls and returns ReadHead#poll(1 time unit). Optionally, a predicate
+ * may be specified.
+ * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the
+ * string "null" or an int. Returns Unit.
+ * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most
+ * y time units.
+ * predicate must be one of "= x", "< x" or "> x".
+ */
+class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
+class InterpretException(
+ threadIndex: Int,
+ lineNum: Int,
+ className: String,
+ methodName: String,
+ fileName: String,
+ cause: Throwable
+) : RuntimeException(cause) {
+ init {
+ stackTrace = arrayOf(StackTraceElement(
+ className,
+ "$methodName:thread$threadIndex",
+ fileName,
+ lineNum)) + super.getStackTrace()
+ }
+}
+
+// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
+private fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
+private fun MatchResult.intArg(index: Int) = strArg(index).toInt()
+private fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)
+
+// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate
+// Returns an always-true predicate for empty and null arguments
+private fun makePredicate(spec: String?): (Int) -> Boolean {
+ if (spec.isNullOrEmpty()) return { true }
+ val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec)
+ if (null == match) throw SyntaxException("Predicate \"${spec}\"")
+ val arg = match.intArg(2)
+ return when (match.strArg(1)) {
+ ">" -> { i -> i > arg }
+ "<" -> { i -> i < arg }
+ "=" -> { i -> i == arg }
+ else -> throw RuntimeException("How did \"${spec}\" match this regexp ?")
+ }
+}
+
+const val DEBUG_INTERPRETER = true
+
+// The table contains pairs associating a regexp with the code to run. The statement is matched
+// against each matcher in sequence and when a match is found the associated code is run, passing
+// it the TrackRecord under test and the result of the regexp match.
+typealias InterpretMatcher = Pair<Regex, (TrackRecord<Int>, MatchResult) -> Any?>
+
+val interpretTable = listOf<InterpretMatcher>(
+ // Interpret an empty line as doing nothing.
+ Regex("") to { _, _ -> null },
+ Regex("(.*)//.*") to { t, r -> interpret(r.strArg(1), t) },
+ // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
+ Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { t, r ->
+ assertTrue(measureTimeMillis { interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3))
+ },
+ // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
+ Regex("""(.*)\s*=\s*(null|\d+)""") to { t, r ->
+ interpret(r.strArg(1), t).also {
+ if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
+ }
+ },
+ // Interpret "XXX is odd" : run XXX and assert its return value is odd ("even" works too)
+ Regex("(.*)\\s+is\\s+(even|odd)") to { t, r ->
+ interpret(r.strArg(1), t).also {
+ assertEquals((it as Int) % 2, if ("even" == r.strArg(2)) 0 else 1)
+ }
+ },
+ // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
+ Regex("""sleep(\((\d+)\))?""") to { t, r ->
+ SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
+ },
+ // Interpret "add(XXX)" as TrackRecord#add(int)
+ Regex("""add\((\d+)\)""") to { t, r ->
+ t.add(r.intArg(1))
+ },
+ // Interpret "poll(x, y)" as TrackRecord#poll(timeout = x * INTERPRET_TIME_UNIT, pos = y)
+ // Accepts an optional {} argument for the predicate (see makePredicate for syntax)
+ Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { t, r ->
+ t.poll(r.timeArg(1), r.intArg(2), makePredicate(r.strArg(3)))
+ },
+ // ReadHead#poll. If this throws in the cast, the code is malformed and has passed "poll()"
+ // in a test that takes a TrackRecord that is not a ReadHead. It's technically possible to get
+ // the test code to not compile instead of throw, but it's vastly more complex and this will
+ // fail 100% at runtime any test that would not have compiled.
+ Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { t, r ->
+ (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
+ (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
+ }
+ },
+ // ReadHead#mark. The same remarks apply as with ReadHead#poll.
+ Regex("mark") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark },
+ // ReadHead#peek. The same remarks apply as with ReadHead#poll.
+ Regex("peek\\(\\)") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() }
+)
+
+// Split the line into multiple statements separated by ";" and execute them. Return whatever
+// the last statement returned.
+private fun <T : TrackRecord<Int>> interpretMultiple(instruction: String, r: T): Any? {
+ return instruction.split(";").map { interpret(it.trim(), r) }.last()
+}
+// Match the statement to a regex and interpret it.
+private fun <T : TrackRecord<Int>> interpret(instr: String, r: T): Any? {
+ val (matcher, code) =
+ interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
+ val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
+ return code(r, match)
+}
+
+// Create the ArrayTrackRecord<Int> under test, then spins as many threads as needed by the test
+// spec and interpret each program concurrently, having all threads waiting on a CyclicBarrier
+// after each line. If |useReadHeads| is true, it will create a ReadHead over the ArrayTrackRecord
+// in each thread and call the interpreted methods on that ; if it's false, it will call the
+// interpreted methods on the ArrayTrackRecord directly. Be careful that some instructions may
+// only be supported on ReadHead, and will throw if called when using useReadHeads = false.
+private fun interpretTestSpec(useReadHeads: Boolean, spec: String) {
+ // For nice stack traces
+ val callSite = getCallingMethod()
+ val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
+ // |threads| contains arrays of strings that make up the statements of a thread : in other
+ // words, it's an array that contains a list of statements for each column in the spec.
+ val threadCount = lines[0].size
+ assertTrue(lines.all { it.size == threadCount })
+ val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
+ val barrier = CyclicBarrier(threadCount)
+ val rec = ArrayTrackRecord<Int>()
+ var crash: InterpretException? = null
+ threadInstructions.mapIndexed { threadIndex, instructions ->
+ Thread {
+ val rh = if (useReadHeads) rec.newReadHead() else rec
+ barrier.await()
+ var lineNum = 0
+ instructions.forEach {
+ if (null != crash) return@Thread
+ lineNum += 1
+ try {
+ interpretMultiple(it, rh)
+ } catch (e: Throwable) {
+ // If fail() or some exception was called, the thread will come here ; if the
+ // exception isn't caught the process will crash, which is not nice for testing.
+ // Instead, catch the exception, cancel other threads, and report nicely.
+ // Catch throwable because fail() is AssertionError, which inherits from Error.
+ crash = InterpretException(threadIndex, callSite.lineNumber + lineNum,
+ callSite.className, callSite.methodName, callSite.fileName, e)
+ }
+ barrier.await()
+ }
+ }.also { it.start() }
+ }.forEach { it.join() }
+ // If the test failed, crash with line number
+ crash?.let { throw it }
+}
+
+private fun getCallingMethod(): StackTraceElement {
+ try {
+ throw RuntimeException()
+ } catch (e: RuntimeException) {
+ return e.stackTrace[3] // 0 is this method here, 1 is interpretTestSpec, 2 the lambda
+ }
+}
diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
index cf5cb4b..8f0974d 100644
--- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -103,7 +103,8 @@ import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.Spy;
-import org.mockito.verification.VerificationWithTimeout;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.HttpURLConnection;
@@ -252,36 +253,33 @@ public class NetworkMonitorTest {
// Queries on mCleartextDnsNetwork using DnsResolver#query.
doAnswer(invocation -> {
- String hostname = (String) invocation.getArgument(1);
- Executor executor = (Executor) invocation.getArgument(3);
- DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);
-
- List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
- if (answer != null && answer.size() > 0) {
- new Handler(Looper.getMainLooper()).post(() -> {
- executor.execute(() -> callback.onAnswer(answer, 0));
- });
- }
- // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
- return null;
+ return mockQuery(invocation, 1 /* posHostname */, 3 /* posExecutor */,
+ 5 /* posCallback */);
}).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
- // Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType.
+ // Queries on mCleartextDnsNetwork using DnsResolver#query with QueryType.
doAnswer(invocation -> {
- String hostname = (String) invocation.getArgument(1);
- Executor executor = (Executor) invocation.getArgument(4);
- DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6);
-
- List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
- if (answer != null && answer.size() > 0) {
- new Handler(Looper.getMainLooper()).post(() -> {
- executor.execute(() -> callback.onAnswer(answer, 0));
- });
- }
- // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
- return null;
+ return mockQuery(invocation, 1 /* posHostname */, 4 /* posExecutor */,
+ 6 /* posCallback */);
}).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
}
+
+ // Mocking queries on DnsResolver#query.
+ private Answer mockQuery(InvocationOnMock invocation, int posHostname, int posExecutor,
+ int posCallback) {
+ String hostname = (String) invocation.getArgument(posHostname);
+ Executor executor = (Executor) invocation.getArgument(posExecutor);
+ DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback);
+
+ List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
+ if (answer != null && answer.size() > 0) {
+ new Handler(Looper.getMainLooper()).post(() -> {
+ executor.execute(() -> callback.onAnswer(answer, 0));
+ });
+ }
+ // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
+ return null;
+ }
}
private FakeDns mFakeDns;