diff options
author | Chalard Jean <jchalard@google.com> | 2020-04-24 02:51:07 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2020-04-24 02:51:07 +0000 |
commit | 9feb7a75f109d3fe6fd3f0aadfa1a253b67dde2d (patch) | |
tree | 665d6ca3c2ab3af5402deae17b0166cf0618660d | |
parent | 804f1f332b7e8a42fad13452a16a4e442ed0753b (diff) | |
parent | 36987092a2145e277705b7c4bfed0098e6d026df (diff) |
Improve testable utils. am: 36987092a2
Change-Id: Ia19f81d1d12bdf1f03728ce3005e68ba985ae314
3 files changed, 66 insertions, 13 deletions
diff --git a/tests/lib/multivariant/com/android/testutils/TrackRecord.kt b/tests/lib/multivariant/com/android/testutils/TrackRecord.kt index ab2dc64..3cdea12 100644 --- a/tests/lib/multivariant/com/android/testutils/TrackRecord.kt +++ b/tests/lib/multivariant/com/android/testutils/TrackRecord.kt @@ -196,7 +196,13 @@ class ArrayTrackRecord<E> : TrackRecord<E> { /** * @return the current value of the mark. */ - val mark get() = readHead.also { checkThread() } + var mark + get() = readHead.also { checkThread() } + set(v: Int) = rewind(v) + fun rewind(v: Int) { + checkThread() + readHead = v + } private fun checkThread() = check(Thread.currentThread() == owningThread) { "Must be called by the thread that created this object" diff --git a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt index cfdff51..c9da2f2 100644 --- a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt +++ b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt @@ -36,6 +36,7 @@ import kotlin.test.assertTrue import kotlin.test.fail object NULL_NETWORK : Network(-1) +object ANY_NETWORK : Network(-2) private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this) @@ -100,7 +101,8 @@ open class RecorderCallback private constructor( } } - protected val history = backingRecord.newReadHead() + val history = backingRecord.newReadHead() + val mark get() = history.mark override fun onAvailable(network: Network) { history.add(Available(network)) @@ -172,17 +174,28 @@ open class TestableNetworkCallback private constructor( if (null != cb) fail("Expected no callback but got $cb") } + // Expects a callback of the specified type on the specified network within the timeout. + // If no callback arrives, or a different callback arrives, fail. Returns the callback. inline fun <reified T : CallbackEntry> expectCallback( - network: Network, + network: Network = ANY_NETWORK, timeoutMs: Long = defaultTimeoutMs ): T = pollForNextCallback(timeoutMs).let { - if (it !is T || it.network != network) { + if (it !is T || (ANY_NETWORK !== network && it.network != network)) { fail("Unexpected callback : $it, expected ${T::class} with Network[$network]") } else { it } } + // Expects a callback of the specified type matching the predicate within the timeout. + // Any callback that doesn't match the predicate will be skipped. Fails only if + // no matching callback is received within the timeout. + inline fun <reified T : CallbackEntry> eventuallyExpect( + timeoutMs: Long = defaultTimeoutMs, + from: Int = mark, + crossinline predicate: (T) -> Boolean = { true } + ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T + fun expectCallbackThat( timeoutMs: Long = defaultTimeoutMs, valid: (CallbackEntry) -> Boolean diff --git a/tests/lib/src/com/android/testutils/TestableNetworkStatsProvider.kt b/tests/lib/src/com/android/testutils/TestableNetworkStatsProvider.kt index 25fe38e..a4ef770 100644 --- a/tests/lib/src/com/android/testutils/TestableNetworkStatsProvider.kt +++ b/tests/lib/src/com/android/testutils/TestableNetworkStatsProvider.kt @@ -18,18 +18,23 @@ package com.android.testutils import android.net.netstats.provider.NetworkStatsProvider import kotlin.test.assertEquals +import kotlin.test.assertTrue import kotlin.test.fail private const val DEFAULT_TIMEOUT_MS = 200L -open class TestableNetworkStatsProvider : NetworkStatsProvider() { +open class TestableNetworkStatsProvider( + val defaultTimeoutMs: Long = DEFAULT_TIMEOUT_MS +) : NetworkStatsProvider() { sealed class CallbackType { data class OnRequestStatsUpdate(val token: Int) : CallbackType() data class OnSetLimit(val iface: String?, val quotaBytes: Long) : CallbackType() data class OnSetAlert(val quotaBytes: Long) : CallbackType() } - private val history = ArrayTrackRecord<CallbackType>().ReadHead() + val history = ArrayTrackRecord<CallbackType>().newReadHead() + // See ReadHead#mark + val mark get() = history.mark override fun onRequestStatsUpdate(token: Int) { history.add(CallbackType.OnRequestStatsUpdate(token)) @@ -43,20 +48,49 @@ open class TestableNetworkStatsProvider : NetworkStatsProvider() { history.add(CallbackType.OnSetAlert(quotaBytes)) } - fun expectOnRequestStatsUpdate(token: Int) { - assertEquals(CallbackType.OnRequestStatsUpdate(token), history.poll(DEFAULT_TIMEOUT_MS)) + fun expectOnRequestStatsUpdate(token: Int, timeout: Long = defaultTimeoutMs) { + assertEquals(CallbackType.OnRequestStatsUpdate(token), history.poll(timeout)) } - fun expectOnSetLimit(iface: String?, quotaBytes: Long) { - assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(DEFAULT_TIMEOUT_MS)) + fun expectOnSetLimit(iface: String?, quotaBytes: Long, timeout: Long = defaultTimeoutMs) { + assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(timeout)) } - fun expectOnSetAlert(quotaBytes: Long) { - assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(DEFAULT_TIMEOUT_MS)) + fun expectOnSetAlert(quotaBytes: Long, timeout: Long = defaultTimeoutMs) { + assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(timeout)) + } + + fun pollForNextCallback(timeout: Long = defaultTimeoutMs) = + history.poll(timeout) ?: fail("Did not receive callback after ${timeout}ms") + + inline fun <reified T : CallbackType> expectCallback( + timeout: Long = defaultTimeoutMs, + predicate: (T) -> Boolean = { true } + ): T { + return pollForNextCallback(timeout).also { assertTrue(it is T && predicate(it)) } as T + } + + // Expects a callback of the specified type matching the predicate within the timeout. + // Any callback that doesn't match the predicate will be skipped. Fails only if + // no matching callback is received within the timeout. + // TODO : factorize the code for this with the identical call in TestableNetworkCallback. + // There should be a common superclass doing this generically. + // TODO : have a better error message to have this fail. Right now the failure when no + // matching callback arrives comes from the casting to a non-nullable T. + // TODO : in fact, completely removing this method and have clients use + // history.poll(timeout, index, predicate) directly might be simpler. + inline fun <reified T : CallbackType> eventuallyExpect( + timeoutMs: Long = defaultTimeoutMs, + from: Int = mark, + crossinline predicate: (T) -> Boolean = { true } + ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T + + fun drainCallbacks() { + history.mark = history.size } @JvmOverloads - fun assertNoCallback(timeout: Long = DEFAULT_TIMEOUT_MS) { + fun assertNoCallback(timeout: Long = defaultTimeoutMs) { val cb = history.poll(timeout) cb?.let { fail("Expected no callback but got $cb") } } |