diff options
5 files changed, 517 insertions, 66 deletions
diff --git a/src/com/android/networkstack/util/DnsUtils.java b/src/com/android/networkstack/util/DnsUtils.java index 4767d55..e68976a 100644 --- a/src/com/android/networkstack/util/DnsUtils.java +++ b/src/com/android/networkstack/util/DnsUtils.java @@ -21,21 +21,25 @@ import static android.net.DnsResolver.TYPE_A; import static android.net.DnsResolver.TYPE_AAAA; import android.annotation.NonNull; +import android.annotation.Nullable; import android.net.DnsResolver; import android.net.Network; import android.net.TrafficStats; +import android.net.util.Stopwatch; import android.util.Log; import com.android.internal.util.TrafficStatsConstants; +import com.android.server.connectivity.NetworkMonitor.DnsLogFunc; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.TimeoutException; /** * Collection of utilities for dns query. @@ -44,6 +48,7 @@ public class DnsUtils { // Decide what queries to make depending on what IP addresses are on the system. public static final int TYPE_ADDRCONFIG = -1; private static final String TAG = DnsUtils.class.getSimpleName(); + private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG); /** * Return both A and AAAA query results regardless the ip address type of the giving network. @@ -51,27 +56,33 @@ public class DnsUtils { */ @NonNull public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver, - @NonNull final Network network, @NonNull String host, int timeout) - throws UnknownHostException { + @NonNull final Network network, @NonNull String host, int timeout, + @NonNull final DnsLogFunc logger) throws UnknownHostException { final List<InetAddress> result = new ArrayList<InetAddress>(); + final StringBuilder errorMsg = new StringBuilder(host); try { result.addAll(Arrays.asList( getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP, - timeout))); + timeout, logger))); } catch (UnknownHostException e) { // Might happen if the host is v4-only, still need to query TYPE_A + errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_AAAA), e.getMessage())); } try { result.addAll(Arrays.asList( getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP, - timeout))); + timeout, logger))); } catch (UnknownHostException e) { // Might happen if the host is v6-only, still need to return AAAA answers + errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_A), e.getMessage())); } + if (result.size() == 0) { - throw new UnknownHostException(host); + logger.log("FAIL: " + errorMsg.toString()); + throw new UnknownHostException(errorMsg.toString()); } + logger.log("OK: " + host + " " + result.toString()); return result.toArray(new InetAddress[0]); } @@ -82,26 +93,34 @@ public class DnsUtils { @NonNull public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver, @NonNull final Network network, @NonNull final String host, int type, int flag, - int timeoutMs) throws UnknownHostException { - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>(); + int timeoutMs, @Nullable final DnsLogFunc logger) throws UnknownHostException { + final CompletableFuture<List<InetAddress>> resultRef = new CompletableFuture<>(); + final Stopwatch watch = new Stopwatch().start(); + final DnsResolver.Callback<List<InetAddress>> callback = - new DnsResolver.Callback<List<InetAddress>>() { + new DnsResolver.Callback<List<InetAddress>>() { @Override public void onAnswer(List<InetAddress> answer, int rcode) { - if (rcode == 0) { - resultRef.set(answer); + if (rcode == 0 && answer != null && answer.size() != 0) { + resultRef.complete(answer); + } else { + resultRef.completeExceptionally(new UnknownHostException()); } - latch.countDown(); } @Override public void onError(@NonNull DnsResolver.DnsException e) { - Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage()); - latch.countDown(); + if (DBG) { + Log.d(TAG, "DNS error resolving " + host, e); + } + resultRef.completeExceptionally(e); } }; + // TODO: Investigate whether this is still useful. + // The packets that actually do the DNS queries are sent by netd, but netd doesn't + // look at the tag at all. Given that this is a library, the tag should be passed in by the + // caller. final int oldTag = TrafficStats.getAndSetThreadStatsTag( TrafficStatsConstants.TAG_SYSTEM_PROBE); @@ -115,16 +134,52 @@ public class DnsUtils { TrafficStats.setThreadStatsTag(oldTag); + List<InetAddress> result = null; + Exception exception = null; try { - latch.await(timeoutMs, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { + result = resultRef.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + exception = e; + } catch (TimeoutException | InterruptedException e) { + exception = new UnknownHostException("Timeout"); + } finally { + logDnsResult(result, watch.stop() /* latency */, logger, type, + exception != null ? exception.getMessage() : "" /* errorMsg */); } - final List<InetAddress> result = resultRef.get(); - if (result == null || result.size() == 0) { - throw new UnknownHostException(host); - } + if (null != exception) throw (UnknownHostException) exception; return result.toArray(new InetAddress[0]); } + + private static void logDnsResult(@Nullable final List<InetAddress> results, final long latency, + @Nullable final DnsLogFunc logger, int type, @NonNull final String errorMsg) { + if (logger == null) { + return; + } + + if (results != null && results.size() != 0) { + final StringBuilder builder = new StringBuilder(); + for (InetAddress address : results) { + builder.append(',').append(address.getHostAddress()); + } + logger.log(String.format("%dms OK %s", latency, builder.substring(1))); + } else { + logger.log(String.format("%dms FAIL in type %s %s", latency, dnsTypeToStr(type), + errorMsg)); + } + } + + private static String dnsTypeToStr(int type) { + switch (type) { + case TYPE_A: + return "A"; + case TYPE_AAAA: + return "AAAA"; + case TYPE_ADDRCONFIG: + return "ADDRCONFIG"; + default: + } + return "UNDEFINED"; + } } diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java index 0b2c051..d4b484d 100644 --- a/src/com/android/server/connectivity/NetworkMonitor.java +++ b/src/com/android/server/connectivity/NetworkMonitor.java @@ -1044,12 +1044,11 @@ public class NetworkMonitor extends StateMachine { try { // Do a blocking DNS resolution using the network-assigned nameservers. final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(), - mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout()); + mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout(), + str -> validationLog("Strict mode hostname resolution " + str)); mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips); - validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig); } catch (UnknownHostException uhe) { mPrivateDnsConfig = null; - validationLog("Strict mode hostname resolution failed: " + uhe.getMessage()); } mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, (mPrivateDnsConfig != null) /* succeeded */); @@ -1153,7 +1152,6 @@ public class NetworkMonitor extends StateMachine { } else if (probeResult.isPartialConnectivity()) { mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_PARTIAL, null /* redirectUrl */); - // Check if disable https probing needed. maybeDisableHttpsProbing(mAcceptPartialConnectivity); if (mAcceptPartialConnectivity) { transitionTo(mEvaluatingPrivateDnsState); @@ -1557,7 +1555,8 @@ public class NetworkMonitor extends StateMachine { protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs) throws UnknownHostException { return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host, - TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs); + TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs, + str -> validationLog(ValidationProbeEvent.PROBE_DNS, host, str)); } /** Do a DNS resolution of the given server. */ @@ -1572,19 +1571,11 @@ public class NetworkMonitor extends StateMachine { String connectInfo; try { InetAddress[] addresses = sendDnsProbeWithTimeout(host, getDnsProbeTimeout()); - StringBuffer buffer = new StringBuffer(); - for (InetAddress address : addresses) { - buffer.append(',').append(address.getHostAddress()); - } result = ValidationProbeEvent.DNS_SUCCESS; - connectInfo = "OK " + buffer.substring(1); } catch (UnknownHostException e) { result = ValidationProbeEvent.DNS_FAILURE; - connectInfo = "FAIL"; } final long latency = watch.stop(); - validationLog(ValidationProbeEvent.PROBE_DNS, host, - String.format("%dms %s", latency, connectInfo)); logValidationProbe(latency, ValidationProbeEvent.PROBE_DNS, result); } @@ -2175,4 +2166,14 @@ public class NetworkMonitor extends StateMachine { } mEvaluationState.noteProbeResult(probeResult, succeeded); } + + /** + * Interface for logging dns results. + */ + public interface DnsLogFunc { + /** + * Log function. + */ + void log(String s); + } } diff --git a/tests/lib/src/com/android/testutils/FileUtils.kt b/tests/lib/src/com/android/testutils/FileUtils.kt new file mode 100644 index 0000000..edd8d83 --- /dev/null +++ b/tests/lib/src/com/android/testutils/FileUtils.kt @@ -0,0 +1,11 @@ +package com.android.testutils + +// This function is private because the 2 is hardcoded here, and is not correct if not called +// directly from __LINE__ or __FILE__. +private fun callerStackTrace(): StackTraceElement = try { + throw RuntimeException() +} catch (e: RuntimeException) { + e.stackTrace[2] // 0 is here, 1 is get() in __FILE__ or __LINE__ +} +val __FILE__: String get() = callerStackTrace().fileName +val __LINE__: Int get() = callerStackTrace().lineNumber diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt index 5f3e81a..77c2cd3 100644 --- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt +++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt @@ -16,10 +16,17 @@ package android.net.testutils +import android.os.SystemClock import com.android.testutils.ArrayTrackRecord +import com.android.testutils.TrackRecord +import com.android.testutils.__FILE__ +import com.android.testutils.__LINE__ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import java.util.concurrent.CyclicBarrier +import java.util.concurrent.TimeUnit +import kotlin.system.measureTimeMillis import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse @@ -30,6 +37,14 @@ import kotlin.test.fail val TEST_VALUES = listOf(4, 13, 52, 94, 41, 68, 11, 13, 51, 0, 91, 94, 33, 98, 14) const val ABSENT_VALUE = 2 +// Caution in changing these : some tests rely on the fact that TEST_TIMEOUT > 2 * SHORT_TIMEOUT +// and LONG_TIMEOUT > 2 * TEST_TIMEOUT +const val SHORT_TIMEOUT = 40L // ms +const val TEST_TIMEOUT = 200L // ms +const val LONG_TIMEOUT = 5000L // ms + +// The unit of time for interpreted tests +const val INTERPRET_TIME_UNIT = SHORT_TIMEOUT @RunWith(JUnit4::class) class TrackRecordTest { @@ -135,19 +150,390 @@ class TrackRecordTest { TEST_VALUES.subList(3, TEST_VALUES.size - 3)) } + fun testPollReturnsImmediately(record: TrackRecord<Int>) { + record.add(4) + val elapsed = measureTimeMillis { assertEquals(4, record.poll(LONG_TIMEOUT, 0)) } + // Should not have waited at all, in fact. + assertTrue(elapsed < LONG_TIMEOUT) + record.add(7) + record.add(9) + // Can poll multiple times for the same position, in whatever order + assertEquals(9, record.poll(0, 2)) + assertEquals(7, record.poll(Long.MAX_VALUE, 1)) + assertEquals(9, record.poll(0, 2)) + assertEquals(4, record.poll(0, 0)) + assertEquals(9, record.poll(0, 2) { it > 5 }) + assertEquals(7, record.poll(0, 0) { it > 5 }) + } + + @Test + fun testPollReturnsImmediately() { + testPollReturnsImmediately(ArrayTrackRecord()) + testPollReturnsImmediately(ArrayTrackRecord<Int>().newReadHead()) + } + + @Test + fun testPollTimesOut() { + val record = ArrayTrackRecord<Int>() + var delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0)) } + assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT") + delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) } + assertTrue(delay > SHORT_TIMEOUT) + } + + @Test + fun testPollWakesUp() { + val record = ArrayTrackRecord<Int>() + val barrier = CyclicBarrier(2) + Thread { + barrier.await(LONG_TIMEOUT, TimeUnit.MILLISECONDS) // barrier 1 + barrier.await() // barrier 2 + Thread.sleep(SHORT_TIMEOUT * 2) + record.add(31) + }.start() + barrier.await() // barrier 1 + // Should find the element in more than SHORT_TIMEOUT but less than TEST_TIMEOUT + var delay = measureTimeMillis { + barrier.await() // barrier 2 + assertEquals(31, record.poll(TEST_TIMEOUT, 0)) + } + assertTrue(delay in SHORT_TIMEOUT..TEST_TIMEOUT) + // Polling for an element already added in anothe thread (pos 0) : should return immediately + delay = measureTimeMillis { assertEquals(31, record.poll(TEST_TIMEOUT, 0)) } + assertTrue(delay < TEST_TIMEOUT, "Delay $delay > $TEST_TIMEOUT") + // Waiting for an element that never comes + delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 1)) } + assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT") + // Polling for an element that doesn't match what is already there + delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) } + assertTrue(delay > SHORT_TIMEOUT) + } + + // Just make sure the interpreter actually throws an exception when the spec + // does not conform to the behavior. The interpreter is just a tool to test a + // tool used for a tool for test, let's not have hundreds of tests for it ; + // if it's broken one of the tests using it will break. + @Test + fun testInterpreter() { + val interpretLine = __LINE__ + 2 + try { + interpretTestSpec(useReadHeads = true, spec = """ + add(4) | poll(1, 0) = 5 + """) + fail("This spec should have thrown") + } catch (e: InterpretException) { + assertTrue(e.cause is AssertionError) + assertEquals(interpretLine + 1, e.stackTrace[0].lineNumber) + assertTrue(e.stackTrace[0].fileName.contains(__FILE__)) + assertTrue(e.stackTrace[0].methodName.contains("testInterpreter")) + assertTrue(e.stackTrace[0].methodName.contains("thread1")) + } + } + + @Test + fun testMultipleAdds() { + interpretTestSpec(useReadHeads = false, spec = """ + add(2) | | | + | add(4) | | + | | add(6) | + | | | add(8) + poll(0, 0) = 2 time 0..1 | poll(0, 0) = 2 | poll(0, 0) = 2 | poll(0, 0) = 2 + poll(0, 1) = 4 time 0..1 | poll(0, 1) = 4 | poll(0, 1) = 4 | poll(0, 1) = 4 + poll(0, 2) = 6 time 0..1 | poll(0, 2) = 6 | poll(0, 2) = 6 | poll(0, 2) = 6 + poll(0, 3) = 8 time 0..1 | poll(0, 3) = 8 | poll(0, 3) = 8 | poll(0, 3) = 8 + """) + } + + @Test + fun testConcurrentAdds() { + interpretTestSpec(useReadHeads = false, spec = """ + add(2) | add(4) | add(6) | add(8) + add(1) | add(3) | add(5) | add(7) + poll(0, 1) is even | poll(0, 0) is even | poll(0, 3) is even | poll(0, 2) is even + poll(0, 5) is odd | poll(0, 4) is odd | poll(0, 7) is odd | poll(0, 6) is odd + """) + } + + @Test + fun testMultiplePoll() { + interpretTestSpec(useReadHeads = false, spec = """ + add(4) | poll(1, 0) = 4 + | poll(0, 1) = null time 0..1 + | poll(1, 1) = null time 1..2 + sleep; add(7) | poll(2, 1) = 7 time 1..2 + sleep; add(18) | poll(2, 2) = 18 time 1..2 + """) + } + + @Test + fun testMultiplePollWithPredicate() { + interpretTestSpec(useReadHeads = false, spec = """ + | poll(1, 0) = null | poll(1, 0) = null + add(6) | poll(1, 0) = 6 | + add(11) | poll(1, 0) { > 20 } = null | poll(1, 0) { = 11 } = 11 + | poll(1, 0) { > 8 } = 11 | + """) + } + + @Test + fun testMultipleReadHeads() { + interpretTestSpec(useReadHeads = true, spec = """ + | poll() = null | poll() = null | poll() = null + add(5) | | poll() = 5 | + | poll() = 5 | | + add(8) | poll() = 8 | poll() = 8 | + | | | poll() = 5 + | | | poll() = 8 + | | | poll() = null + | | poll() = null | + """) + } + + @Test + fun testReadHeadPollWithPredicate() { + interpretTestSpec(useReadHeads = true, spec = """ + add(5) | poll() { < 0 } = null + | poll() { > 5 } = null + add(10) | + | poll() { = 5 } = null // The "5" was skipped in the previous line + add(15) | poll() { > 8 } = 15 // The "10" was skipped in the previous line + | poll(1, 0) { > 8 } = 10 // 10 is the first element after pos 0 matching > 8 + """) + } + + @Test + fun testPollImmediatelyAdvancesReadhead() { + interpretTestSpec(useReadHeads = true, spec = """ + add(1) | add(2) | add(3) | add(4) + mark = 0 | poll(0) { > 3 } = 4 | | + poll(0) { > 10 } = null | | | + mark = 4 | | | + poll() = null | | | + """) + } + + @Test + fun testParallelReadHeads() { + interpretTestSpec(useReadHeads = true, spec = """ + mark = 0 | mark = 0 | mark = 0 | mark = 0 + add(2) | | | + | add(4) | | + | | add(6) | + | | | add(8) + poll() = 2 | poll() = 2 | poll() = 2 | poll() = 2 + poll() = 4 | poll() = 4 | poll() = 4 | poll() = 4 + poll() = 6 | poll() = 6 | poll() = 6 | mark = 2 + poll() = 8 | poll() = 8 | mark = 3 | poll() = 6 + mark = 4 | mark = 4 | poll() = 8 | poll() = 8 + """) + } + + @Test + fun testPeek() { + interpretTestSpec(useReadHeads = true, spec = """ + add(2) | | | + | add(4) | | + | | add(6) | + | | | add(8) + peek() = 2 | poll() = 2 | poll() = 2 | peek() = 2 + peek() = 2 | peek() = 4 | poll() = 4 | peek() = 2 + peek() = 2 | peek() = 4 | peek() = 6 | poll() = 2 + peek() = 2 | mark = 1 | mark = 2 | poll() = 4 + mark = 0 | peek() = 4 | peek() = 6 | peek() = 6 + poll() = 2 | poll() = 4 | poll() = 6 | poll() = 6 + poll() = 4 | mark = 2 | poll() = 8 | peek() = 8 + peek() = 6 | peek() = 6 | peek() = null | mark = 3 + """) + } /** - * // TODO : add the following tests. + * // TODO : don't submit without this. * Test poll() - * - Put stuff, check that it's returned immediately - * - Check that it waits and times out - * - Check that it waits and finds the stuff added through the timeout - * - Put stuff, check that it's returned immediately when it matches the predicate * - Check that it immediately finds added stuff that matches * Test ReadHead#poll() * - All of the above, and: * - Put stuff, check that it timeouts when it doesn't match the predicate, and the read head * has advanced * - Check that it immediately advances the read head + * - Check multiple read heads in different threads * Test ReadHead#peek() */ } + +/** + * A small interpreter for testing parallel code. The interpreter will read a list of lines + * consisting of "|"-separated statements. Each column runs in a different concurrent thread + * and all threads wait for each other in between lines. Each statement is split on ";" then + * matched with regular expressions in the instructionTable constant, which contains the + * code associated with each statement. + * + * The time unit is defined in milliseconds by the INTERPRET_TIME_UNIT constant. Whitespace is + * ignored. Quick ref of supported expressions : + * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1) + * add(x) : calls and returns TrackRecord#add. + * poll(time, pos) [{ predicate }] : calls and returns TrackRecord#poll(x time units, pos). + * Optionally, a predicate may be specified. + * poll() [{ predicate }] : calls and returns ReadHead#poll(1 time unit). Optionally, a predicate + * may be specified. + * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the + * string "null" or an int. Returns Unit. + * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most + * y time units. + * predicate must be one of "= x", "< x" or "> x". + */ +class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause) +class InterpretException( + threadIndex: Int, + lineNum: Int, + className: String, + methodName: String, + fileName: String, + cause: Throwable +) : RuntimeException(cause) { + init { + stackTrace = arrayOf(StackTraceElement( + className, + "$methodName:thread$threadIndex", + fileName, + lineNum)) + super.getStackTrace() + } +} + +// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time +private fun MatchResult.strArg(index: Int) = this.groupValues[index].trim() +private fun MatchResult.intArg(index: Int) = strArg(index).toInt() +private fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index) + +// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate +// Returns an always-true predicate for empty and null arguments +private fun makePredicate(spec: String?): (Int) -> Boolean { + if (spec.isNullOrEmpty()) return { true } + val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec) + if (null == match) throw SyntaxException("Predicate \"${spec}\"") + val arg = match.intArg(2) + return when (match.strArg(1)) { + ">" -> { i -> i > arg } + "<" -> { i -> i < arg } + "=" -> { i -> i == arg } + else -> throw RuntimeException("How did \"${spec}\" match this regexp ?") + } +} + +const val DEBUG_INTERPRETER = true + +// The table contains pairs associating a regexp with the code to run. The statement is matched +// against each matcher in sequence and when a match is found the associated code is run, passing +// it the TrackRecord under test and the result of the regexp match. +typealias InterpretMatcher = Pair<Regex, (TrackRecord<Int>, MatchResult) -> Any?> + +val interpretTable = listOf<InterpretMatcher>( + // Interpret an empty line as doing nothing. + Regex("") to { _, _ -> null }, + Regex("(.*)//.*") to { t, r -> interpret(r.strArg(1), t) }, + // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y + Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { t, r -> + assertTrue(measureTimeMillis { interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3)) + }, + // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported + Regex("""(.*)\s*=\s*(null|\d+)""") to { t, r -> + interpret(r.strArg(1), t).also { + if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it) + } + }, + // Interpret "XXX is odd" : run XXX and assert its return value is odd ("even" works too) + Regex("(.*)\\s+is\\s+(even|odd)") to { t, r -> + interpret(r.strArg(1), t).also { + assertEquals((it as Int) % 2, if ("even" == r.strArg(2)) 0 else 1) + } + }, + // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units. + Regex("""sleep(\((\d+)\))?""") to { t, r -> + SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2)) + }, + // Interpret "add(XXX)" as TrackRecord#add(int) + Regex("""add\((\d+)\)""") to { t, r -> + t.add(r.intArg(1)) + }, + // Interpret "poll(x, y)" as TrackRecord#poll(timeout = x * INTERPRET_TIME_UNIT, pos = y) + // Accepts an optional {} argument for the predicate (see makePredicate for syntax) + Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { t, r -> + t.poll(r.timeArg(1), r.intArg(2), makePredicate(r.strArg(3))) + }, + // ReadHead#poll. If this throws in the cast, the code is malformed and has passed "poll()" + // in a test that takes a TrackRecord that is not a ReadHead. It's technically possible to get + // the test code to not compile instead of throw, but it's vastly more complex and this will + // fail 100% at runtime any test that would not have compiled. + Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { t, r -> + (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time -> + (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2))) + } + }, + // ReadHead#mark. The same remarks apply as with ReadHead#poll. + Regex("mark") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark }, + // ReadHead#peek. The same remarks apply as with ReadHead#poll. + Regex("peek\\(\\)") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() } +) + +// Split the line into multiple statements separated by ";" and execute them. Return whatever +// the last statement returned. +private fun <T : TrackRecord<Int>> interpretMultiple(instruction: String, r: T): Any? { + return instruction.split(";").map { interpret(it.trim(), r) }.last() +} +// Match the statement to a regex and interpret it. +private fun <T : TrackRecord<Int>> interpret(instr: String, r: T): Any? { + val (matcher, code) = + interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr) + val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr) + return code(r, match) +} + +// Create the ArrayTrackRecord<Int> under test, then spins as many threads as needed by the test +// spec and interpret each program concurrently, having all threads waiting on a CyclicBarrier +// after each line. If |useReadHeads| is true, it will create a ReadHead over the ArrayTrackRecord +// in each thread and call the interpreted methods on that ; if it's false, it will call the +// interpreted methods on the ArrayTrackRecord directly. Be careful that some instructions may +// only be supported on ReadHead, and will throw if called when using useReadHeads = false. +private fun interpretTestSpec(useReadHeads: Boolean, spec: String) { + // For nice stack traces + val callSite = getCallingMethod() + val lines = spec.trim().trim('\n').split("\n").map { it.split("|") } + // |threads| contains arrays of strings that make up the statements of a thread : in other + // words, it's an array that contains a list of statements for each column in the spec. + val threadCount = lines[0].size + assertTrue(lines.all { it.size == threadCount }) + val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } } + val barrier = CyclicBarrier(threadCount) + val rec = ArrayTrackRecord<Int>() + var crash: InterpretException? = null + threadInstructions.mapIndexed { threadIndex, instructions -> + Thread { + val rh = if (useReadHeads) rec.newReadHead() else rec + barrier.await() + var lineNum = 0 + instructions.forEach { + if (null != crash) return@Thread + lineNum += 1 + try { + interpretMultiple(it, rh) + } catch (e: Throwable) { + // If fail() or some exception was called, the thread will come here ; if the + // exception isn't caught the process will crash, which is not nice for testing. + // Instead, catch the exception, cancel other threads, and report nicely. + // Catch throwable because fail() is AssertionError, which inherits from Error. + crash = InterpretException(threadIndex, callSite.lineNumber + lineNum, + callSite.className, callSite.methodName, callSite.fileName, e) + } + barrier.await() + } + }.also { it.start() } + }.forEach { it.join() } + // If the test failed, crash with line number + crash?.let { throw it } +} + +private fun getCallingMethod(): StackTraceElement { + try { + throw RuntimeException() + } catch (e: RuntimeException) { + return e.stackTrace[3] // 0 is this method here, 1 is interpretTestSpec, 2 the lambda + } +} diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java index cf5cb4b..8f0974d 100644 --- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java +++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java @@ -103,7 +103,8 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.Spy; -import org.mockito.verification.VerificationWithTimeout; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.io.IOException; import java.net.HttpURLConnection; @@ -252,36 +253,33 @@ public class NetworkMonitorTest { // Queries on mCleartextDnsNetwork using DnsResolver#query. doAnswer(invocation -> { - String hostname = (String) invocation.getArgument(1); - Executor executor = (Executor) invocation.getArgument(3); - DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5); - - List<InetAddress> answer = getAnswer(invocation.getMock(), hostname); - if (answer != null && answer.size() > 0) { - new Handler(Looper.getMainLooper()).post(() -> { - executor.execute(() -> callback.onAnswer(answer, 0)); - }); - } - // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE. - return null; + return mockQuery(invocation, 1 /* posHostname */, 3 /* posExecutor */, + 5 /* posCallback */); }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any()); - // Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType. + // Queries on mCleartextDnsNetwork using DnsResolver#query with QueryType. doAnswer(invocation -> { - String hostname = (String) invocation.getArgument(1); - Executor executor = (Executor) invocation.getArgument(4); - DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6); - - List<InetAddress> answer = getAnswer(invocation.getMock(), hostname); - if (answer != null && answer.size() > 0) { - new Handler(Looper.getMainLooper()).post(() -> { - executor.execute(() -> callback.onAnswer(answer, 0)); - }); - } - // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE. - return null; + return mockQuery(invocation, 1 /* posHostname */, 4 /* posExecutor */, + 6 /* posCallback */); }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any()); } + + // Mocking queries on DnsResolver#query. + private Answer mockQuery(InvocationOnMock invocation, int posHostname, int posExecutor, + int posCallback) { + String hostname = (String) invocation.getArgument(posHostname); + Executor executor = (Executor) invocation.getArgument(posExecutor); + DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback); + + List<InetAddress> answer = getAnswer(invocation.getMock(), hostname); + if (answer != null && answer.size() > 0) { + new Handler(Looper.getMainLooper()).post(() -> { + executor.execute(() -> callback.onAnswer(answer, 0)); + }); + } + // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE. + return null; + } } private FakeDns mFakeDns; |