summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/networkstackclient/Android.bp2
-rw-r--r--common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl1
-rw-r--r--src/android/net/apf/ApfFilter.java269
-rw-r--r--src/android/net/dhcp/DhcpClient.java160
-rw-r--r--src/android/net/util/FdEventsReader.java32
-rw-r--r--src/android/net/util/PacketReader.java2
-rw-r--r--src/com/android/server/connectivity/NetworkMonitor.java78
-rw-r--r--tests/integration/src/android/net/ip/IpClientIntegrationTest.java8
-rw-r--r--tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt17
-rw-r--r--tests/lib/src/com/android/testutils/TestableNetworkCallback.kt85
-rw-r--r--tests/unit/Android.bp1
-rw-r--r--tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java4
-rw-r--r--tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt293
-rw-r--r--tests/unit/src/android/net/testutils/TrackRecordTest.kt3
-rw-r--r--tests/unit/src/android/net/util/PacketReaderTest.java32
-rw-r--r--tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java66
16 files changed, 769 insertions, 284 deletions
diff --git a/common/networkstackclient/Android.bp b/common/networkstackclient/Android.bp
index ccb3f45..c82f751 100644
--- a/common/networkstackclient/Android.bp
+++ b/common/networkstackclient/Android.bp
@@ -87,6 +87,6 @@ java_library {
],
static_libs: [
"ipmemorystore-aidl-interfaces-V3-java",
- "networkstack-aidl-interfaces-V3-java",
+ "networkstack-aidl-interfaces-java",
],
}
diff --git a/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl b/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl
index 2c61511..f8dcd6c 100644
--- a/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl
+++ b/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl
@@ -26,4 +26,5 @@ oneway interface INetworkMonitorCallbacks {
void notifyPrivateDnsConfigResolved(in PrivateDnsConfigParcel config);
void showProvisioningNotification(String action, String packageName);
void hideProvisioningNotification();
+ void notifyProbeStatusChanged(int probesCompleted, int probesSucceeded);
} \ No newline at end of file
diff --git a/src/android/net/apf/ApfFilter.java b/src/android/net/apf/ApfFilter.java
index 1d3421c..2f74ad6 100644
--- a/src/android/net/apf/ApfFilter.java
+++ b/src/android/net/apf/ApfFilter.java
@@ -57,7 +57,6 @@ import android.system.ErrnoException;
import android.system.Os;
import android.text.format.DateUtils;
import android.util.Log;
-import android.util.Pair;
import android.util.SparseArray;
import com.android.internal.annotations.GuardedBy;
@@ -500,6 +499,44 @@ public class ApfFilter {
}
}
+ /**
+ * Class to keep track of a section in a packet.
+ */
+ private static class PacketSection {
+ public enum Type {
+ MATCH, // A field that should be matched (e.g., the router IP address).
+ IGNORE, // An ignored field such as the checksum of the flow label. Not matched.
+ LIFETIME, // A lifetime. Not matched, and generally counts toward minimum RA lifetime.
+ }
+
+ /** The type of section. */
+ public final Type type;
+ /** Offset into the packet at which this section begins. */
+ public final int start;
+ /** Length of this section. */
+ public final int length;
+ /** If this is a lifetime, the ICMP option that the defined it. 0 for router lifetime. */
+ public final int option;
+ /** If this is a lifetime, the lifetime value. */
+ public final long lifetime;
+
+ PacketSection(int start, int length, Type type, int option, long lifetime) {
+ this.start = start;
+ this.length = length;
+ this.type = type;
+ this.option = option;
+ this.lifetime = lifetime;
+ }
+
+ public String toString() {
+ if (type == Type.LIFETIME) {
+ return String.format("%s: (%d, %d) %d %d", type, start, length, option, lifetime);
+ } else {
+ return String.format("%s: (%d, %d)", type, start, length);
+ }
+ }
+ }
+
// A class to hold information about an RA.
@VisibleForTesting
class Ra {
@@ -534,10 +571,10 @@ public class ApfFilter {
// Note: mPacket's position() cannot be assumed to be reset.
private final ByteBuffer mPacket;
- // List of binary ranges that include the whole packet except the lifetimes.
- // Pairs consist of offset and length.
- private final ArrayList<Pair<Integer, Integer>> mNonLifetimes =
- new ArrayList<Pair<Integer, Integer>>();
+
+ // List of sections in the packet.
+ private final ArrayList<PacketSection> mPacketSections = new ArrayList<>();
+
// Minimum lifetime in packet
long mMinLifetime;
// When the packet was last captured, in seconds since Unix Epoch
@@ -649,27 +686,65 @@ public class ApfFilter {
}
/**
- * Add a binary range of the packet that does not include a lifetime to mNonLifetimes.
- * Assumes mPacket.position() is as far as we've parsed the packet.
- * @param lastNonLifetimeStart offset within packet of where the last binary range of
- * data not including a lifetime.
- * @param lifetimeOffset offset from mPacket.position() to the next lifetime data.
- * @param lifetimeLength length of the next lifetime data.
- * @return offset within packet of where the next binary range of data not including
- * a lifetime. This can be passed into the next invocation of this function
- * via {@code lastNonLifetimeStart}.
+ * Add a packet section that should be matched, starting from the current position.
+ * @param length the length of the section
+ */
+ private void addMatchSection(int length) {
+ // Don't generate JNEBS instruction for 0 bytes as they will fail the
+ // ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1) check (where cmp_imm is
+ // the number of bytes to compare) and immediately pass the packet.
+ // The code does not attempt to generate such matches, but add a safety
+ // check to prevent doing so in the presence of bugs or malformed or
+ // truncated packets.
+ if (length == 0) return;
+ mPacketSections.add(
+ new PacketSection(mPacket.position(), length, PacketSection.Type.MATCH, 0, 0));
+ mPacket.position(mPacket.position() + length);
+ }
+
+ /**
+ * Add a packet section that should be matched, starting from the current position.
+ * @param end the offset in the packet before which the section ends
+ */
+ private void addMatchUntil(int end) {
+ addMatchSection(end - mPacket.position());
+ }
+
+ /**
+ * Add a packet section that should be ignored, starting from the current position.
+ * @param length the length of the section in bytes
*/
- private int addNonLifetime(int lastNonLifetimeStart, int lifetimeOffset,
- int lifetimeLength) {
- lifetimeOffset += mPacket.position();
- mNonLifetimes.add(new Pair<Integer, Integer>(lastNonLifetimeStart,
- lifetimeOffset - lastNonLifetimeStart));
- return lifetimeOffset + lifetimeLength;
+ private void addIgnoreSection(int length) {
+ mPacketSections.add(
+ new PacketSection(mPacket.position(), length, PacketSection.Type.IGNORE, 0, 0));
+ mPacket.position(mPacket.position() + length);
}
- private int addNonLifetimeU32(int lastNonLifetimeStart) {
- return addNonLifetime(lastNonLifetimeStart,
- ICMP6_4_BYTE_LIFETIME_OFFSET, ICMP6_4_BYTE_LIFETIME_LEN);
+ /**
+ * Add a packet section that represents a lifetime, starting from the current position.
+ * @param length the length of the section in bytes
+ * @param optionType the RA option containing this lifetime, or 0 for router lifetime
+ * @param lifetime the lifetime
+ */
+ private void addLifetimeSection(int length, int optionType, long lifetime) {
+ mPacketSections.add(
+ new PacketSection(mPacket.position(), length, PacketSection.Type.LIFETIME,
+ optionType, lifetime));
+ mPacket.position(mPacket.position() + length);
+ }
+
+ /**
+ * Adds packet sections for an RA option with a 4-byte lifetime 4 bytes into the option
+ * @param optionType the RA option that is being added
+ * @param optionLength the length of the option in bytes
+ */
+ private long add4ByteLifetimeOption(int optionType, int optionLength) {
+ addMatchSection(ICMP6_4_BYTE_LIFETIME_OFFSET);
+ final long lifetime = getUint32(mPacket, mPacket.position());
+ addLifetimeSection(ICMP6_4_BYTE_LIFETIME_LEN, optionType, lifetime);
+ addMatchSection(optionLength - ICMP6_4_BYTE_LIFETIME_OFFSET
+ - ICMP6_4_BYTE_LIFETIME_LEN);
+ return lifetime;
}
// Note that this parses RA and may throw InvalidRaException (from
@@ -696,20 +771,18 @@ public class ApfFilter {
RaEvent.Builder builder = new RaEvent.Builder();
// Ignore the flow label and low 4 bits of traffic class.
- int lastNonLifetimeStart = addNonLifetime(0,
- IPV6_FLOW_LABEL_OFFSET,
- IPV6_FLOW_LABEL_LEN);
+ addMatchUntil(IPV6_FLOW_LABEL_OFFSET);
+ addIgnoreSection(IPV6_FLOW_LABEL_LEN);
- // Ignore the checksum.
- lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart,
- ICMP6_RA_CHECKSUM_OFFSET,
- ICMP6_RA_CHECKSUM_LEN);
+ // Ignore checksum.
+ addMatchUntil(ICMP6_RA_CHECKSUM_OFFSET);
+ addIgnoreSection(ICMP6_RA_CHECKSUM_LEN);
// Parse router lifetime
- lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart,
- ICMP6_RA_ROUTER_LIFETIME_OFFSET,
- ICMP6_RA_ROUTER_LIFETIME_LEN);
- builder.updateRouterLifetime(getUint16(mPacket, ICMP6_RA_ROUTER_LIFETIME_OFFSET));
+ addMatchUntil(ICMP6_RA_ROUTER_LIFETIME_OFFSET);
+ final long routerLifetime = getUint16(mPacket, ICMP6_RA_ROUTER_LIFETIME_OFFSET);
+ addLifetimeSection(ICMP6_RA_ROUTER_LIFETIME_LEN, 0, routerLifetime);
+ builder.updateRouterLifetime(routerLifetime);
// Ensures that the RA is not truncated.
mPacket.position(ICMP6_RA_OPTION_OFFSET);
@@ -720,64 +793,62 @@ public class ApfFilter {
long lifetime;
switch (optionType) {
case ICMP6_PREFIX_OPTION_TYPE:
+ mPrefixOptionOffsets.add(position);
+
// Parse valid lifetime
- lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart,
- ICMP6_PREFIX_OPTION_VALID_LIFETIME_OFFSET,
- ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN);
- lifetime = getUint32(mPacket,
- position + ICMP6_PREFIX_OPTION_VALID_LIFETIME_OFFSET);
+ addMatchSection(ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN);
+ lifetime = getUint32(mPacket, mPacket.position());
+ addLifetimeSection(ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN,
+ ICMP6_PREFIX_OPTION_TYPE, lifetime);
builder.updatePrefixValidLifetime(lifetime);
+
// Parse preferred lifetime
- lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart,
- ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_OFFSET,
- ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_LEN);
- lifetime = getUint32(mPacket,
- position + ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_OFFSET);
+ lifetime = getUint32(mPacket, mPacket.position());
+ addLifetimeSection(ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_LEN,
+ ICMP6_PREFIX_OPTION_TYPE, lifetime);
builder.updatePrefixPreferredLifetime(lifetime);
- mPrefixOptionOffsets.add(position);
+
+ addMatchSection(4); // Reserved bytes
+ addMatchSection(IPV6_ADDR_LEN); // The prefix itself
break;
// These three options have the same lifetime offset and size, and
- // are processed with the same specialized addNonLifetimeU32:
+ // are processed with the same specialized add4ByteLifetimeOption:
case ICMP6_RDNSS_OPTION_TYPE:
mRdnssOptionOffsets.add(position);
- lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart);
- lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET);
+ lifetime = add4ByteLifetimeOption(optionType, optionLength);
builder.updateRdnssLifetime(lifetime);
break;
case ICMP6_ROUTE_INFO_OPTION_TYPE:
mRioOptionOffsets.add(position);
- lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart);
- lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET);
+ lifetime = add4ByteLifetimeOption(optionType, optionLength);
builder.updateRouteInfoLifetime(lifetime);
break;
case ICMP6_DNSSL_OPTION_TYPE:
- lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart);
- lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET);
+ lifetime = add4ByteLifetimeOption(optionType, optionLength);
builder.updateDnsslLifetime(lifetime);
break;
default:
- // RFC4861 section 4.2 dictates we ignore unknown options for fowards
+ // RFC4861 section 4.2 dictates we ignore unknown options for forwards
// compatibility.
+ mPacket.position(position + optionLength);
break;
}
if (optionLength <= 0) {
throw new InvalidRaException(String.format(
"Invalid option length opt=%d len=%d", optionType, optionLength));
}
- mPacket.position(position + optionLength);
}
- // Mark non-lifetime bytes since last lifetime.
- addNonLifetime(lastNonLifetimeStart, 0, 0);
- mMinLifetime = minLifetime(packet, length);
+ mMinLifetime = minLifetime();
mMetricsLog.log(builder.build());
}
- // Ignoring lifetimes (which may change) does {@code packet} match this RA?
+ // Considering only the MATCH sections, does {@code packet} match this RA?
boolean matches(byte[] packet, int length) {
if (length != mPacket.capacity()) return false;
byte[] referencePacket = mPacket.array();
- for (Pair<Integer, Integer> nonLifetime : mNonLifetimes) {
- for (int i = nonLifetime.first; i < (nonLifetime.first + nonLifetime.second); i++) {
+ for (PacketSection section : mPacketSections) {
+ if (section.type != PacketSection.Type.MATCH) continue;
+ for (int i = section.start; i < (section.start + section.length); i++) {
if (packet[i] != referencePacket[i]) return false;
}
}
@@ -786,36 +857,12 @@ public class ApfFilter {
// What is the minimum of all lifetimes within {@code packet} in seconds?
// Precondition: matches(packet, length) already returned true.
- long minLifetime(byte[] packet, int length) {
+ long minLifetime() {
long minLifetime = Long.MAX_VALUE;
- // Wrap packet in ByteBuffer so we can read big-endian values easily
- ByteBuffer byteBuffer = ByteBuffer.wrap(packet);
- for (int i = 0; (i + 1) < mNonLifetimes.size(); i++) {
- int offset = mNonLifetimes.get(i).first + mNonLifetimes.get(i).second;
-
- // The flow label is in mNonLifetimes, but it's not a lifetime.
- if (offset == IPV6_FLOW_LABEL_OFFSET) {
- continue;
- }
-
- // The checksum is in mNonLifetimes, but it's not a lifetime.
- if (offset == ICMP6_RA_CHECKSUM_OFFSET) {
- continue;
- }
-
- final int lifetimeLength = mNonLifetimes.get(i+1).first - offset;
- final long optionLifetime;
- switch (lifetimeLength) {
- case 2:
- optionLifetime = getUint16(byteBuffer, offset);
- break;
- case 4:
- optionLifetime = getUint32(byteBuffer, offset);
- break;
- default:
- throw new IllegalStateException("bogus lifetime size " + lifetimeLength);
+ for (PacketSection section : mPacketSections) {
+ if (section.type == PacketSection.Type.LIFETIME) {
+ minLifetime = Math.min(minLifetime, section.lifetime);
}
- minLifetime = Math.min(minLifetime, optionLifetime);
}
return minLifetime;
}
@@ -844,38 +891,24 @@ public class ApfFilter {
// Skip filter if expired
gen.addLoadFromMemory(Register.R0, gen.FILTER_AGE_MEMORY_SLOT);
gen.addJumpIfR0GreaterThan(filterLifetime, nextFilterLabel);
- for (int i = 0; i < mNonLifetimes.size(); i++) {
- // Generate code to match the packet bytes
- Pair<Integer, Integer> nonLifetime = mNonLifetimes.get(i);
- // Don't generate JNEBS instruction for 0 bytes as it always fails the
- // ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1) check where cmp_imm is
- // the number of bytes to compare. nonLifetime is zero between the
- // valid and preferred lifetimes in the prefix option.
- if (nonLifetime.second != 0) {
- gen.addLoadImmediate(Register.R0, nonLifetime.first);
+ for (PacketSection section : mPacketSections) {
+ // Generate code to match the packet bytes.
+ if (section.type == PacketSection.Type.MATCH) {
+ gen.addLoadImmediate(Register.R0, section.start);
gen.addJumpIfBytesNotEqual(Register.R0,
- Arrays.copyOfRange(mPacket.array(), nonLifetime.first,
- nonLifetime.first + nonLifetime.second),
+ Arrays.copyOfRange(mPacket.array(), section.start,
+ section.start + section.length),
nextFilterLabel);
}
- // Generate code to test the lifetimes haven't gone down too far
- if ((i + 1) < mNonLifetimes.size()) {
- Pair<Integer, Integer> nextNonLifetime = mNonLifetimes.get(i + 1);
- int offset = nonLifetime.first + nonLifetime.second;
-
- // Skip the Flow label.
- if (offset == IPV6_FLOW_LABEL_OFFSET) {
- continue;
- }
- // Skip the checksum.
- if (offset == ICMP6_RA_CHECKSUM_OFFSET) {
- continue;
- }
- int length = nextNonLifetime.first - offset;
- switch (length) {
- case 4: gen.addLoad32(Register.R0, offset); break;
- case 2: gen.addLoad16(Register.R0, offset); break;
- default: throw new IllegalStateException("bogus lifetime size " + length);
+ // Generate code to test the lifetimes haven't gone down too far.
+ // The packet is accepted if any of its lifetimes are lower than filterLifetime.
+ if (section.type == PacketSection.Type.LIFETIME) {
+ switch (section.length) {
+ case 4: gen.addLoad32(Register.R0, section.start); break;
+ case 2: gen.addLoad16(Register.R0, section.start); break;
+ default:
+ throw new IllegalStateException(
+ "bogus lifetime size " + section.length);
}
gen.addJumpIfR0LessThan(filterLifetime, nextFilterLabel);
}
@@ -1682,7 +1715,7 @@ public class ApfFilter {
if (VDBG) log("matched RA " + ra);
// Update lifetimes.
ra.mLastSeen = currentTimeSeconds();
- ra.mMinLifetime = ra.minLifetime(packet, length);
+ ra.mMinLifetime = ra.minLifetime();
ra.seenCount++;
// Keep mRas in LRU order so as to prioritize generating filters for recently seen
diff --git a/src/android/net/dhcp/DhcpClient.java b/src/android/net/dhcp/DhcpClient.java
index bf63c1c..e88c7cc 100644
--- a/src/android/net/dhcp/DhcpClient.java
+++ b/src/android/net/dhcp/DhcpClient.java
@@ -37,6 +37,7 @@ import static android.system.OsConstants.AF_PACKET;
import static android.system.OsConstants.ETH_P_IP;
import static android.system.OsConstants.IPPROTO_UDP;
import static android.system.OsConstants.SOCK_DGRAM;
+import static android.system.OsConstants.SOCK_NONBLOCK;
import static android.system.OsConstants.SOCK_RAW;
import static android.system.OsConstants.SOL_SOCKET;
import static android.system.OsConstants.SO_BROADCAST;
@@ -46,6 +47,7 @@ import static android.system.OsConstants.SO_REUSEADDR;
import static com.android.server.util.NetworkStackConstants.IPV4_ADDR_ANY;
import android.annotation.NonNull;
+import android.annotation.Nullable;
import android.content.Context;
import android.net.DhcpResults;
import android.net.InetAddresses;
@@ -60,7 +62,9 @@ import android.net.metrics.DhcpErrorEvent;
import android.net.metrics.IpConnectivityLog;
import android.net.util.InterfaceParams;
import android.net.util.NetworkStackUtils;
+import android.net.util.PacketReader;
import android.net.util.SocketUtils;
+import android.os.Handler;
import android.os.Message;
import android.os.SystemClock;
import android.system.ErrnoException;
@@ -210,14 +214,9 @@ public class DhcpClient extends StateMachine {
private final Random mRandom;
private final IpConnectivityLog mMetricsLog = new IpConnectivityLog();
- // Sockets.
- // - We use a packet socket to receive, because servers send us packets bound for IP addresses
- // which we have not yet configured, and the kernel protocol stack drops these.
- // - We use a UDP socket to send, so the kernel handles ARP and routing for us (DHCP servers can
- // be off-link as well as on-link).
- private FileDescriptor mPacketSock;
+ // We use a UDP socket to send, so the kernel handles ARP and routing for us (DHCP servers can
+ // be off-link as well as on-link).
private FileDescriptor mUdpSock;
- private ReceiveThread mReceiveThread;
// State variables.
private final StateMachine mController;
@@ -244,6 +243,8 @@ public class DhcpClient extends StateMachine {
private Dependencies mDependencies;
@NonNull
private final NetworkStackIpMemoryStore mIpMemoryStore;
+ @Nullable
+ private DhcpPacketHandler mDhcpPacketHandler;
// Milliseconds SystemClock timestamps used to record transition times to DhcpBoundState.
private long mLastInitEnterTime;
@@ -396,23 +397,6 @@ public class DhcpClient extends StateMachine {
mTransactionStartMillis = SystemClock.elapsedRealtime();
}
- private boolean initSockets() {
- return initPacketSocket() && initUdpSocket();
- }
-
- private boolean initPacketSocket() {
- try {
- mPacketSock = Os.socket(AF_PACKET, SOCK_RAW, ETH_P_IP);
- SocketAddress addr = makePacketSocketAddress(ETH_P_IP, mIface.index);
- Os.bind(mPacketSock, addr);
- NetworkStackUtils.attachDhcpFilter(mPacketSock);
- } catch(SocketException|ErrnoException e) {
- Log.e(TAG, "Error creating packet socket", e);
- return false;
- }
- return true;
- }
-
private boolean initUdpSocket() {
final int oldTag = TrafficStats.getAndSetThreadStatsTag(
TrafficStatsConstants.TAG_SYSTEM_DHCP);
@@ -423,7 +407,7 @@ public class DhcpClient extends StateMachine {
Os.setsockoptInt(mUdpSock, SOL_SOCKET, SO_BROADCAST, 1);
Os.setsockoptInt(mUdpSock, SOL_SOCKET, SO_RCVBUF, 0);
Os.bind(mUdpSock, IPV4_ADDR_ANY, DhcpPacket.DHCP_CLIENT);
- } catch(SocketException|ErrnoException e) {
+ } catch (SocketException | ErrnoException e) {
Log.e(TAG, "Error creating UDP socket", e);
return false;
} finally {
@@ -436,59 +420,76 @@ public class DhcpClient extends StateMachine {
try {
Os.connect(mUdpSock, to, DhcpPacket.DHCP_SERVER);
return true;
- } catch (SocketException|ErrnoException e) {
+ } catch (SocketException | ErrnoException e) {
Log.e(TAG, "Error connecting UDP socket", e);
return false;
}
}
- private void closeSockets() {
- closeSocketQuietly(mUdpSock);
- closeSocketQuietly(mPacketSock);
- }
+ private class DhcpPacketHandler extends PacketReader {
+ private FileDescriptor mPacketSock;
- class ReceiveThread extends Thread {
+ DhcpPacketHandler(Handler handler) {
+ super(handler);
+ }
- private final byte[] mPacket = new byte[DhcpPacket.MAX_LENGTH];
- private volatile boolean mStopped = false;
+ @Override
+ protected void handlePacket(byte[] recvbuf, int length) {
+ try {
+ final DhcpPacket packet = DhcpPacket.decodeFullPacket(recvbuf, length,
+ DhcpPacket.ENCAP_L2);
+ if (DBG) Log.d(TAG, "Received packet: " + packet);
+ sendMessage(CMD_RECEIVED_PACKET, packet);
+ } catch (DhcpPacket.ParseException e) {
+ Log.e(TAG, "Can't parse packet: " + e.getMessage());
+ if (PACKET_DBG) {
+ Log.d(TAG, HexDump.dumpHexString(recvbuf, 0, length));
+ }
+ if (e.errorCode == DhcpErrorEvent.DHCP_NO_COOKIE) {
+ final int snetTagId = 0x534e4554;
+ final String bugId = "31850211";
+ final int uid = -1;
+ final String data = DhcpPacket.ParseException.class.getName();
+ EventLog.writeEvent(snetTagId, bugId, uid, data);
+ }
+ mMetricsLog.log(mIfaceName, new DhcpErrorEvent(e.errorCode));
+ }
+ }
- public void halt() {
- mStopped = true;
- closeSockets(); // Interrupts the read() call the thread is blocked in.
+ @Override
+ protected FileDescriptor createFd() {
+ try {
+ mPacketSock = Os.socket(AF_PACKET, SOCK_RAW | SOCK_NONBLOCK, 0 /* protocol */);
+ NetworkStackUtils.attachDhcpFilter(mPacketSock);
+ final SocketAddress addr = makePacketSocketAddress(ETH_P_IP, mIface.index);
+ Os.bind(mPacketSock, addr);
+ } catch (SocketException | ErrnoException e) {
+ logError("Error creating packet socket", e);
+ closeFd(mPacketSock);
+ mPacketSock = null;
+ return null;
+ }
+ return mPacketSock;
}
@Override
- public void run() {
- if (DBG) Log.d(TAG, "Receive thread started");
- while (!mStopped) {
- int length = 0; // Or compiler can't tell it's initialized if a parse error occurs.
- try {
- length = Os.read(mPacketSock, mPacket, 0, mPacket.length);
- DhcpPacket packet = null;
- packet = DhcpPacket.decodeFullPacket(mPacket, length, DhcpPacket.ENCAP_L2);
- if (DBG) Log.d(TAG, "Received packet: " + packet);
- sendMessage(CMD_RECEIVED_PACKET, packet);
- } catch (IOException|ErrnoException e) {
- if (!mStopped) {
- Log.e(TAG, "Read error", e);
- logError(DhcpErrorEvent.RECEIVE_ERROR);
- }
- } catch (DhcpPacket.ParseException e) {
- Log.e(TAG, "Can't parse packet: " + e.getMessage());
- if (PACKET_DBG) {
- Log.d(TAG, HexDump.dumpHexString(mPacket, 0, length));
- }
- if (e.errorCode == DhcpErrorEvent.DHCP_NO_COOKIE) {
- int snetTagId = 0x534e4554;
- String bugId = "31850211";
- int uid = -1;
- String data = DhcpPacket.ParseException.class.getName();
- EventLog.writeEvent(snetTagId, bugId, uid, data);
- }
- logError(e.errorCode);
- }
+ protected int readPacket(FileDescriptor fd, byte[] packetBuffer) throws Exception {
+ try {
+ return Os.read(fd, packetBuffer, 0, packetBuffer.length);
+ } catch (IOException | ErrnoException e) {
+ mMetricsLog.log(mIfaceName, new DhcpErrorEvent(DhcpErrorEvent.RECEIVE_ERROR));
+ throw e;
}
- if (DBG) Log.d(TAG, "Receive thread stopped");
+ }
+
+ @Override
+ protected void logError(@NonNull String msg, @Nullable Exception e) {
+ Log.e(TAG, msg, e);
+ }
+
+ public int transmitPacket(final ByteBuffer buf, final SocketAddress socketAddress)
+ throws ErrnoException, SocketException {
+ return Os.sendto(mPacketSock, buf.array(), 0, buf.limit(), 0, socketAddress);
}
}
@@ -500,7 +501,7 @@ public class DhcpClient extends StateMachine {
try {
if (encap == DhcpPacket.ENCAP_L2) {
if (DBG) Log.d(TAG, "Broadcasting " + description);
- Os.sendto(mPacketSock, buf.array(), 0, buf.limit(), 0, mInterfaceBroadcastAddr);
+ mDhcpPacketHandler.transmitPacket(buf, mInterfaceBroadcastAddr);
} else if (encap == DhcpPacket.ENCAP_BOOTP && to.equals(INADDR_BROADCAST)) {
if (DBG) Log.d(TAG, "Broadcasting " + description);
// We only send L3-encapped broadcasts in DhcpRebindingState,
@@ -517,7 +518,7 @@ public class DhcpClient extends StateMachine {
description, Os.getpeername(mUdpSock)));
Os.write(mUdpSock, buf);
}
- } catch(ErrnoException|IOException e) {
+ } catch (ErrnoException | IOException e) {
Log.e(TAG, "Can't send packet: ", e);
return false;
}
@@ -774,21 +775,22 @@ public class DhcpClient extends StateMachine {
@Override
public void enter() {
clearDhcpState();
- if (initInterface() && initSockets()) {
- mReceiveThread = new ReceiveThread();
- mReceiveThread.start();
- } else {
- notifyFailure();
- transitionTo(mStoppedState);
+ if (initInterface() && initUdpSocket()) {
+ mDhcpPacketHandler = new DhcpPacketHandler(getHandler());
+ if (mDhcpPacketHandler.start()) return;
+ Log.e(TAG, "Fail to start DHCP Packet Handler");
}
+ notifyFailure();
+ transitionTo(mStoppedState);
}
@Override
public void exit() {
- if (mReceiveThread != null) {
- mReceiveThread.halt(); // Also closes sockets.
- mReceiveThread = null;
+ if (mDhcpPacketHandler != null) {
+ mDhcpPacketHandler.stop();
+ if (DBG) Log.d(TAG, "DHCP Packet Handler stopped");
}
+ closeSocketQuietly(mUdpSock);
clearDhcpState();
}
@@ -1289,10 +1291,6 @@ public class DhcpClient extends StateMachine {
class DhcpRebootingState extends LoggingState {
}
- private void logError(int errorCode) {
- mMetricsLog.log(mIfaceName, new DhcpErrorEvent(errorCode));
- }
-
private void logState(String name, int durationMs) {
final DhcpClientEvent event = new DhcpClientEvent.Builder()
.setMsg(name)
diff --git a/src/android/net/util/FdEventsReader.java b/src/android/net/util/FdEventsReader.java
index 1380ea7..e82c69b 100644
--- a/src/android/net/util/FdEventsReader.java
+++ b/src/android/net/util/FdEventsReader.java
@@ -63,7 +63,6 @@ import java.io.IOException;
* the Handler constructor argument is associated.
*
* @param <BufferType> the type of the buffer used to read data.
- * @hide
*/
public abstract class FdEventsReader<BufferType> {
private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
@@ -93,27 +92,21 @@ public abstract class FdEventsReader<BufferType> {
}
/** Start this FdEventsReader. */
- public void start() {
- if (onCorrectThread()) {
- createAndRegisterFd();
- } else {
- mHandler.post(() -> {
- logError("start() called from off-thread", null);
- createAndRegisterFd();
- });
+ public boolean start() {
+ if (!onCorrectThread()) {
+ throw new IllegalStateException("start() called from off-thread");
}
+
+ return createAndRegisterFd();
}
/** Stop this FdEventsReader and destroy the file descriptor. */
public void stop() {
- if (onCorrectThread()) {
- unregisterAndDestroyFd();
- } else {
- mHandler.post(() -> {
- logError("stop() called from off-thread", null);
- unregisterAndDestroyFd();
- });
+ if (!onCorrectThread()) {
+ throw new IllegalStateException("stop() called from off-thread");
}
+
+ unregisterAndDestroyFd();
}
@NonNull
@@ -178,8 +171,8 @@ public abstract class FdEventsReader<BufferType> {
*/
protected void onStop() {}
- private void createAndRegisterFd() {
- if (mFd != null) return;
+ private boolean createAndRegisterFd() {
+ if (mFd != null) return true;
try {
mFd = createFd();
@@ -189,7 +182,7 @@ public abstract class FdEventsReader<BufferType> {
mFd = null;
}
- if (mFd == null) return;
+ if (mFd == null) return false;
mQueue.addOnFileDescriptorEventListener(
mFd,
@@ -205,6 +198,7 @@ public abstract class FdEventsReader<BufferType> {
return FD_EVENTS;
});
onStart();
+ return true;
}
private boolean isRunning() {
diff --git a/src/android/net/util/PacketReader.java b/src/android/net/util/PacketReader.java
index 4aec6b6..0be7187 100644
--- a/src/android/net/util/PacketReader.java
+++ b/src/android/net/util/PacketReader.java
@@ -28,8 +28,6 @@ import java.io.FileDescriptor;
*
* TODO: rename this class to something more correctly descriptive (something
* like [or less horrible than] FdReadEventsHandler?).
- *
- * @hide
*/
public abstract class PacketReader extends FdEventsReader<byte[]> {
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java
index 585e38e..e08170a 100644
--- a/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -368,7 +368,13 @@ public class NetworkMonitor extends StateMachine {
} catch (RemoteException e) {
version = 0;
}
- if (version == Build.VERSION_CODES.CUR_DEVELOPMENT) version = 0;
+ // The AIDL was freezed from Q beta 5 but it's unfreezing from R before releasing. In order
+ // to distinguish the behavior between R and Q beta 5 and before Q beta 5, add SDK and
+ // CODENAME check here. Basically, it's only expected to return 0 for Q beta 4 and below
+ // because the test result has changed.
+ if (Build.VERSION.SDK_INT == Build.VERSION_CODES.Q
+ && Build.VERSION.CODENAME.equals("REL")
+ && version == Build.VERSION_CODES.CUR_DEVELOPMENT) version = 0;
return version;
}
@@ -555,6 +561,14 @@ public class NetworkMonitor extends StateMachine {
}
}
+ private void notifyProbeStatusChanged(int probesCompleted, int probesSucceeded) {
+ try {
+ mCallback.notifyProbeStatusChanged(probesCompleted, probesSucceeded);
+ } catch (RemoteException e) {
+ Log.e(TAG, "Error sending probe status", e);
+ }
+ }
+
private void showProvisioningNotification(String action) {
try {
mCallback.showProvisioningNotification(action, mContext.getPackageName());
@@ -667,6 +681,8 @@ public class NetworkMonitor extends StateMachine {
// no resolved IP addresses, IPs unreachable,
// port 853 unreachable, port 853 is not running a
// DNS-over-TLS server, et cetera).
+ // Cancel any outstanding CMD_EVALUATE_PRIVATE_DNS.
+ removeMessages(CMD_EVALUATE_PRIVATE_DNS);
sendMessage(CMD_EVALUATE_PRIVATE_DNS);
break;
}
@@ -1020,11 +1036,19 @@ public class NetworkMonitor extends StateMachine {
handlePrivateDnsEvaluationFailure();
break;
}
+ handlePrivateDnsEvaluationSuccess();
+ } else {
+ mEvaluationState.removeProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS);
}
// All good!
transitionTo(mValidatedState);
break;
+ case CMD_PRIVATE_DNS_SETTINGS_CHANGED:
+ // When settings change the reevaluation timer must be reset.
+ mPrivateDnsReevalDelayMs = INITIAL_REEVALUATE_DELAY_MS;
+ // Let the message bubble up and be handled by parent states as usual.
+ return NOT_HANDLED;
default:
return NOT_HANDLED;
}
@@ -1051,8 +1075,6 @@ public class NetworkMonitor extends StateMachine {
} catch (UnknownHostException uhe) {
mPrivateDnsConfig = null;
}
- mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
- (mPrivateDnsConfig != null) /* succeeded */);
}
private void notifyPrivateDnsConfigResolved() {
@@ -1063,7 +1085,14 @@ public class NetworkMonitor extends StateMachine {
}
}
+ private void handlePrivateDnsEvaluationSuccess() {
+ mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
+ true /* succeeded */);
+ }
+
private void handlePrivateDnsEvaluationFailure() {
+ mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
+ false /* succeeded */);
mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_INVALID,
null /* redirectUrl */);
// Queue up a re-evaluation with backoff.
@@ -1072,10 +1101,6 @@ public class NetworkMonitor extends StateMachine {
// transitioning back to EvaluatingState, to perhaps give ourselves
// the opportunity to (re)detect a captive portal or something.
//
- // TODO: distinguish between CMD_EVALUATE_PRIVATE_DNS messages that are caused by server
- // lookup failures (which should continue to do exponential backoff) and
- // CMD_EVALUATE_PRIVATE_DNS messages that are caused by user reconfiguration (which
- // should be processed immediately.
sendMessageDelayed(CMD_EVALUATE_PRIVATE_DNS, mPrivateDnsReevalDelayMs);
mPrivateDnsReevalDelayMs *= 2;
if (mPrivateDnsReevalDelayMs > MAX_REEVALUATE_DELAY_MS) {
@@ -1101,7 +1126,6 @@ public class NetworkMonitor extends StateMachine {
String.format("%dms - Error: %s", time, uhe.getMessage()));
}
logValidationProbe(time, PROBE_PRIVDNS, success ? DNS_SUCCESS : DNS_FAILURE);
- mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, success);
return success;
}
}
@@ -2105,22 +2129,43 @@ public class NetworkMonitor extends StateMachine {
// Indicates which probes have completed since clearProbeResults was called.
// This is a bitmask of INetworkMonitor.NETWORK_VALIDATION_PROBE_* constants.
private int mProbeResults = 0;
+ // A bitmask to record which probes are completed.
+ private int mProbeCompleted = 0;
// The latest redirect URL.
private String mRedirectUrl;
protected void clearProbeResults() {
mProbeResults = 0;
+ mProbeCompleted = 0;
}
- // Probe result for http probe should be updated from reportHttpProbeResult().
- protected void noteProbeResult(int probeResult, boolean succeeded) {
- if (succeeded) {
- mProbeResults |= probeResult;
- } else {
- mProbeResults &= ~probeResult;
+ private void maybeNotifyProbeResults(@NonNull final Runnable modif) {
+ final int oldCompleted = mProbeCompleted;
+ final int oldResults = mProbeResults;
+ modif.run();
+ if (oldCompleted != mProbeCompleted || oldResults != mProbeResults) {
+ notifyProbeStatusChanged(mProbeCompleted, mProbeResults);
}
}
+ protected void removeProbeResult(final int probeResult) {
+ maybeNotifyProbeResults(() -> {
+ mProbeCompleted &= ~probeResult;
+ mProbeResults &= ~probeResult;
+ });
+ }
+
+ protected void noteProbeResult(final int probeResult, final boolean succeeded) {
+ maybeNotifyProbeResults(() -> {
+ mProbeCompleted |= probeResult;
+ if (succeeded) {
+ mProbeResults |= probeResult;
+ } else {
+ mProbeResults &= ~probeResult;
+ }
+ });
+ }
+
protected void reportEvaluationResult(int result, @Nullable String redirectUrl) {
mEvaluationResult = result;
mRedirectUrl = redirectUrl;
@@ -2139,6 +2184,11 @@ public class NetworkMonitor extends StateMachine {
}
return mEvaluationResult | mProbeResults;
}
+
+ @VisibleForTesting
+ protected int getProbeCompletedResult() {
+ return mProbeCompleted;
+ }
}
@VisibleForTesting
diff --git a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
index cb7d418..f32171c 100644
--- a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
+++ b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
@@ -130,6 +130,7 @@ public class IpClientIntegrationTest {
private String mIfaceName;
private INetd mNetd;
private HandlerThread mPacketReaderThread;
+ private Handler mHandler;
private TapPacketReader mPacketReader;
private IpClient mIpc;
private Dependencies mDependencies;
@@ -271,7 +272,7 @@ public class IpClientIntegrationTest {
@After
public void tearDown() throws Exception {
if (mPacketReader != null) {
- mPacketReader.stop(); // Also closes the socket
+ mHandler.post(() -> mPacketReader.stop()); // Also closes the socket
}
if (mPacketReaderThread != null) {
mPacketReaderThread.quitSafely();
@@ -297,10 +298,11 @@ public class IpClientIntegrationTest {
mIfaceName = iface.getInterfaceName();
mPacketReaderThread = new HandlerThread(IpClientIntegrationTest.class.getSimpleName());
mPacketReaderThread.start();
+ mHandler = mPacketReaderThread.getThreadHandler();
final ParcelFileDescriptor tapFd = iface.getFileDescriptor();
- mPacketReader = new TapPacketReader(mPacketReaderThread.getThreadHandler(), tapFd);
- mPacketReader.start();
+ mPacketReader = new TapPacketReader(mHandler, tapFd);
+ mHandler.post(() -> mPacketReader.start());
}
private void setUpIpClient() throws Exception {
diff --git a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
index 7796dcd..25b1e0f 100644
--- a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
+++ b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt
@@ -4,6 +4,7 @@ import android.os.SystemClock
import java.util.concurrent.CyclicBarrier
import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
+import kotlin.test.assertFails
import kotlin.test.assertNull
import kotlin.test.assertTrue
@@ -64,7 +65,15 @@ open class ConcurrentIntepreter<T>(
// Spins as many threads as needed by the test spec and interpret each program concurrently,
// having all threads waiting on a CyclicBarrier after each line.
- fun interpretTestSpec(spec: String, initial: T, threadTransform: (T) -> T = { it }) {
+ // |lineShift| says how many lines after the call the spec starts. This is used for error
+ // reporting. Unfortunately AFAICT there is no way to get the line of an argument rather
+ // than the line at which the expression starts.
+ fun interpretTestSpec(
+ spec: String,
+ initial: T,
+ lineShift: Int = 0,
+ threadTransform: (T) -> T = { it }
+ ) {
// For nice stack traces
val callSite = getCallingMethod()
val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
@@ -91,7 +100,8 @@ open class ConcurrentIntepreter<T>(
// testing. Instead, catch the exception, cancel other threads, and report
// nicely. Catch throwable because fail() is AssertionError, which inherits
// from Error.
- crash = InterpretException(threadIndex, it, callSite.lineNumber + lineNum,
+ crash = InterpretException(threadIndex, it,
+ callSite.lineNumber + lineNum + lineShift,
callSite.className, callSite.methodName, callSite.fileName, e)
}
barrier.await()
@@ -147,6 +157,9 @@ private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
// Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+ },
+ Regex("""(.*)\s*fails""") to { i, t, r ->
+ assertFails { i.interpret(r.strArg(1), t) }
}
)
diff --git a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
index 1cc1168..bbb279e 100644
--- a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
+++ b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt
@@ -21,11 +21,15 @@ import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
-import com.android.testutils.RecorderCallback.CallbackRecord.Available
-import com.android.testutils.RecorderCallback.CallbackRecord.BlockedStatus
-import com.android.testutils.RecorderCallback.CallbackRecord.CapabilitiesChanged
-import com.android.testutils.RecorderCallback.CallbackRecord.LinkPropertiesChanged
-import com.android.testutils.RecorderCallback.CallbackRecord.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Losing
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertTrue
@@ -35,38 +39,43 @@ object NULL_NETWORK : Network(-1)
private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
-open class RecorderCallback : NetworkCallback() {
- sealed class CallbackRecord {
+open class RecorderCallback private constructor(
+ private val backingRecord: ArrayTrackRecord<CallbackEntry>
+) : NetworkCallback() {
+ public constructor() : this(ArrayTrackRecord())
+ protected constructor(src: RecorderCallback?): this(src?.backingRecord ?: ArrayTrackRecord())
+
+ sealed class CallbackEntry {
// To get equals(), hashcode(), componentN() etc for free, the child classes of
// this class are data classes. But while data classes can inherit from other classes,
// they may only have visible members in the constructors, so they couldn't declare
- // a constructor with a non-val arg to pass to CallbackRecord. Instead, force all
+ // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
// subclasses to implement a `network' property, which can be done in a data class
// constructor by specifying override.
abstract val network: Network
- data class Available(override val network: Network) : CallbackRecord()
+ data class Available(override val network: Network) : CallbackEntry()
data class CapabilitiesChanged(
override val network: Network,
val caps: NetworkCapabilities
- ) : CallbackRecord()
+ ) : CallbackEntry()
data class LinkPropertiesChanged(
override val network: Network,
val lp: LinkProperties
- ) : CallbackRecord()
- data class Suspended(override val network: Network) : CallbackRecord()
- data class Resumed(override val network: Network) : CallbackRecord()
- data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackRecord()
- data class Lost(override val network: Network) : CallbackRecord()
+ ) : CallbackEntry()
+ data class Suspended(override val network: Network) : CallbackEntry()
+ data class Resumed(override val network: Network) : CallbackEntry()
+ data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
+ data class Lost(override val network: Network) : CallbackEntry()
data class Unavailable private constructor(
override val network: Network
- ) : CallbackRecord() {
+ ) : CallbackEntry() {
constructor() : this(NULL_NETWORK)
}
data class BlockedStatus(
override val network: Network,
val blocked: Boolean
- ) : CallbackRecord()
+ ) : CallbackEntry()
// Convenience constants for expecting a type
companion object {
@@ -91,12 +100,15 @@ open class RecorderCallback : NetworkCallback() {
}
}
- protected val history = ArrayTrackRecord<CallbackRecord>().newReadHead()
+ protected val history = backingRecord.newReadHead()
override fun onAvailable(network: Network) {
history.add(Available(network))
}
+ // PreCheck is not used in the tests today. For backward compatibility with existing tests that
+ // expect the callbacks not to record this, do not listen to PreCheck here.
+
override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
history.add(CapabilitiesChanged(network, caps))
}
@@ -110,39 +122,46 @@ open class RecorderCallback : NetworkCallback() {
}
override fun onNetworkSuspended(network: Network) {
- history.add(CallbackRecord.Suspended(network))
+ history.add(Suspended(network))
}
override fun onNetworkResumed(network: Network) {
- history.add(CallbackRecord.Resumed(network))
+ history.add(Resumed(network))
}
override fun onLosing(network: Network, maxMsToLive: Int) {
- history.add(CallbackRecord.Losing(network, maxMsToLive))
+ history.add(Losing(network, maxMsToLive))
}
override fun onLost(network: Network) {
- history.add(CallbackRecord.Lost(network))
+ history.add(Lost(network))
}
override fun onUnavailable() {
- history.add(CallbackRecord.Unavailable())
+ history.add(Unavailable())
}
}
-typealias CallbackType = KClass<out RecorderCallback.CallbackRecord>
-const val DEFAULT_TIMEOUT = 200L // ms
+private const val DEFAULT_TIMEOUT = 200L // ms
+
+open class TestableNetworkCallback private constructor(
+ src: TestableNetworkCallback?,
+ val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+) : RecorderCallback(src) {
+ @JvmOverloads
+ constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+
+ fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
-open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
- : RecorderCallback() {
- // The last available network. Null if the last available network was lost since.
+ // The last available network, or null if any network was lost since the last call to
+ // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
val lastAvailableNetwork: Network?
get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
is Available -> it.network
else -> null
}
- fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackRecord {
+ fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
}
@@ -153,7 +172,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
if (null != cb) fail("Expected no callback but got $cb")
}
- inline fun <reified T : CallbackRecord> expectCallback(
+ inline fun <reified T : CallbackEntry> expectCallback(
network: Network,
timeoutMs: Long = defaultTimeoutMs
): T = pollForNextCallback(timeoutMs).let {
@@ -166,7 +185,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
fun expectCallbackThat(
timeoutMs: Long = defaultTimeoutMs,
- valid: (CallbackRecord) -> Boolean
+ valid: (CallbackEntry) -> Boolean
) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
fun expectCapabilitiesThat(
@@ -209,7 +228,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
) {
expectCallback<Available>(net, tmt)
if (suspended) {
- expectCallback<CallbackRecord.Suspended>(net, tmt)
+ expectCallback<CallbackEntry.Suspended>(net, tmt)
}
expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
expectCallback<LinkPropertiesChanged>(net, tmt)
@@ -257,7 +276,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT)
}
@JvmOverloads
- open fun <T : CallbackRecord> expectCallback(
+ open fun <T : CallbackEntry> expectCallback(
type: KClass<T>,
n: HasNetwork?,
timeoutMs: Long = defaultTimeoutMs
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index 3410d0e..03bcf95 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -22,6 +22,7 @@ android_test {
resource_dirs: ["res"],
static_libs: [
"androidx.test.rules",
+ "kotlin-reflect",
"mockito-target-extended-minus-junit4",
"net-tests-utils",
"NetworkStackApiCurrentLib",
diff --git a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java
index 64b168a..428baac 100644
--- a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java
+++ b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java
@@ -62,6 +62,8 @@ public class IpReachabilityMonitorTest {
@Test
public void testNothing() {
- IpReachabilityMonitor monitor = makeMonitor();
+ // make sure the unit test runs in the same thread with main looper.
+ // Otherwise, throwing IllegalStateException would cause test fails.
+ mHandler.post(() -> makeMonitor());
}
}
diff --git a/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt
new file mode 100644
index 0000000..4e4d25a
--- /dev/null
+++ b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt
@@ -0,0 +1,293 @@
+package android.net.testutils
+
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.testutils.ConcurrentIntepreter
+import com.android.testutils.InterpretMatcher
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.intArg
+import com.android.testutils.strArg
+import com.android.testutils.timeArg
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+const val SHORT_TIMEOUT_MS = 20L
+const val DEFAULT_LINGER_DELAY_MS = 30000
+const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+const val WIFI = NetworkCapabilities.TRANSPORT_WIFI
+const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR
+const val TEST_INTERFACE_NAME = "testInterfaceName"
+
+@RunWith(JUnit4::class)
+class TestableNetworkCallbackTest {
+ private lateinit var mCallback: TestableNetworkCallback
+
+ private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
+ override val network: Network = Network(netId)
+ }
+
+ @Before
+ fun setUp() {
+ mCallback = TestableNetworkCallback()
+ }
+
+ @Test
+ fun testLastAvailableNetwork() {
+ // Make sure there is no last available network at first, then the last available network
+ // is returned after onAvailable is called.
+ val net2097 = Network(2097)
+ assertNull(mCallback.lastAvailableNetwork)
+ mCallback.onAvailable(net2097)
+ assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+ // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available
+ // network.
+ mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
+ mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
+ assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+ // Make sure onLost clears the last available network.
+ mCallback.onLost(net2097)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Do the same but with a different network after onLost : make sure the last available
+ // network is the new one, not the original one.
+ val net2098 = Network(2098)
+ mCallback.onAvailable(net2098)
+ mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
+ mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
+ assertEquals(mCallback.lastAvailableNetwork, net2098)
+
+ // Make sure onAvailable changes the last available network even if onLost was not called.
+ val net2099 = Network(2099)
+ mCallback.onAvailable(net2099)
+ assertEquals(mCallback.lastAvailableNetwork, net2099)
+
+ // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
+ // the last available one. Check that behavior.
+ mCallback.onLost(net2098)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Make sure that losing the really last available one still results in null.
+ mCallback.onLost(net2099)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Make sure multiple onAvailable in a row then onLost still results in null.
+ mCallback.onAvailable(net2097)
+ mCallback.onAvailable(net2098)
+ mCallback.onAvailable(net2099)
+ mCallback.onLost(net2097)
+ assertNull(mCallback.lastAvailableNetwork)
+ }
+
+ @Test
+ fun testAssertNoCallback() {
+ mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
+ mCallback.onAvailable(Network(100))
+ assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
+ }
+
+ @Test
+ fun testCapabilitiesWithAndWithout() {
+ val net = Network(101)
+ val matcher = makeHasNetwork(101)
+ val meteredNc = NetworkCapabilities()
+ val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
+ // Check that expecting caps (with or without) fails when no callback has been received.
+ assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+ // Add NOT_METERED and check that With succeeds and Without fails.
+ mCallback.onCapabilitiesChanged(net, unmeteredNc)
+ mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+ mCallback.onCapabilitiesChanged(net, unmeteredNc)
+ assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+ // Don't add NOT_METERED and check that With fails and Without succeeds.
+ mCallback.onCapabilitiesChanged(net, meteredNc)
+ assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ mCallback.onCapabilitiesChanged(net, meteredNc)
+ mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+ }
+
+ @Test
+ fun testExpectCallbackThat() {
+ val net = Network(193)
+ val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
+ // Check that expecting callbackThat anything fails when no callback has been received.
+ assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onAvailable(net)
+ mCallback.expectCallbackThat { true }
+ mCallback.onAvailable(net)
+ assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and a negative case
+ mCallback.onBlockedStatusChanged(net, true)
+ mCallback.expectCallbackThat { cb -> cb is BlockedStatus && cb.blocked }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { cb ->
+ cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
+ } }
+ }
+
+ @Test
+ fun testCapabilitiesThat() {
+ val net = Network(101)
+ val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
+ // Check that expecting capabilitiesThat anything fails when no callback has been received.
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ mCallback.expectCapabilitiesThat(net) { true }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and a negative case
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ mCallback.expectCapabilitiesThat(net) { caps ->
+ caps.hasCapability(NOT_METERED) &&
+ caps.hasTransport(WIFI) &&
+ !caps.hasTransport(CELLULAR)
+ }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
+ caps.hasTransport(CELLULAR)
+ } }
+
+ // Try a matching callback on the wrong network
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+ }
+
+ @Test
+ fun testLinkPropertiesThat() {
+ val net = Network(112)
+ val linkAddress = LinkAddress("fe80::ace:d00d/64")
+ val mtu = 1984
+ val linkProps = LinkProperties().apply {
+ this.mtu = mtu
+ interfaceName = TEST_INTERFACE_NAME
+ addLinkAddress(linkAddress)
+ }
+
+ // Check that expecting linkPropsThat anything fails when no callback has been received.
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ mCallback.expectLinkPropertiesThat(net) { true }
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and negative case
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ mCallback.expectLinkPropertiesThat(net) { lp ->
+ lp.interfaceName == TEST_INTERFACE_NAME &&
+ lp.linkAddresses.contains(linkAddress) &&
+ lp.mtu == mtu
+ }
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
+ lp.interfaceName != TEST_INTERFACE_NAME
+ } }
+
+ // Try a matching callback on the wrong network
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
+ lp.interfaceName == TEST_INTERFACE_NAME
+ } }
+ }
+
+ @Test
+ fun testExpectCallback() {
+ val net = Network(103)
+ // Test expectCallback fails when nothing was sent.
+ assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+
+ // Test onAvailable is seen and can be expected
+ mCallback.onAvailable(net)
+ mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS)
+
+ // Test onAvailable won't return calls with a different network
+ mCallback.onAvailable(Network(106))
+ assertFails { mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) }
+
+ // Test onAvailable won't return calls with a different callback
+ mCallback.onAvailable(net)
+ assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+ }
+
+ @Test
+ fun testPollForNextCallback() {
+ assertFails { mCallback.pollForNextCallback(SHORT_TIMEOUT_MS) }
+ TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+ threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+ sleep; onAvailable(133) | poll(2) = Available(133) time 1..4
+ | poll(1) fails
+ onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
+ onBlockedStatus(199) | poll(1) = BlockedStatus(199) time 0..3
+ """)
+ }
+}
+
+private object TNCInterpreter : ConcurrentIntepreter<TestableNetworkCallback>(interpretTable)
+
+val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|")
+private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> {
+ return CallbackEntry::class.sealedSubclasses.first { it.simpleName == name }
+}
+
+private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
+ // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for
+ // all callback types. This is implemented above by enumerating the subclasses of
+ // CallbackEntry and reading their simpleName.
+ Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
+ val record = i.interpret(t.strArg(1), cb)
+ assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record))
+ // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
+ // but the compiler can't figure things out from the isInstance call. It does understand
+ // from the assertTrue(is CallbackEntry) that this is true, which allows to access
+ // the 'network' member below.
+ assertTrue(record is CallbackEntry)
+ assertEquals(record.network.netId, t.intArg(3))
+ },
+ // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
+ // all callback types. NetworkCapabilities and LinkProperties just get an empty object
+ // as their argument. Losing gets the default linger timer. Blocked gets false.
+ Regex("""on($EntryList)\((\d+)\)""") to { i, cb, t ->
+ val net = Network(t.intArg(2))
+ when (t.strArg(1)) {
+ "Available" -> cb.onAvailable(net)
+ // PreCheck not used in tests. Add it here if it becomes useful.
+ "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
+ "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
+ "Suspended" -> cb.onNetworkSuspended(net)
+ "Resumed" -> cb.onNetworkResumed(net)
+ "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
+ "Lost" -> cb.onLost(net)
+ "Unavailable" -> cb.onUnavailable()
+ "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
+ else -> fail("Unknown callback type")
+ }
+ },
+ Regex("""poll\((\d+)\)""") to { i, cb, t ->
+ cb.pollForNextCallback(t.timeArg(1))
+ }
+)
diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
index 4fe8d37..995d537 100644
--- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt
+++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
@@ -352,7 +352,8 @@ class TrackRecordTest {
private object TRTInterpreter : ConcurrentIntepreter<TrackRecord<Int>>(interpretTable) {
fun interpretTestSpec(spec: String, useReadHeads: Boolean) = if (useReadHeads) {
- interpretTestSpec(spec, ArrayTrackRecord(), { (it as ArrayTrackRecord).newReadHead() })
+ interpretTestSpec(spec, initial = ArrayTrackRecord(),
+ threadTransform = { (it as ArrayTrackRecord).newReadHead() })
} else {
interpretTestSpec(spec, ArrayTrackRecord())
}
diff --git a/tests/unit/src/android/net/util/PacketReaderTest.java b/tests/unit/src/android/net/util/PacketReaderTest.java
index 289dcad..3947d15 100644
--- a/tests/unit/src/android/net/util/PacketReaderTest.java
+++ b/tests/unit/src/android/net/util/PacketReaderTest.java
@@ -24,6 +24,8 @@ import static android.system.OsConstants.SOCK_NONBLOCK;
import static android.system.OsConstants.SOL_SOCKET;
import static android.system.OsConstants.SO_SNDTIMEO;
+import static com.android.testutils.MiscAssertsKt.assertThrows;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@@ -182,7 +184,7 @@ public class PacketReaderTest {
assertTrue(Arrays.equals(two, mLastRecvBuf));
assertFalse(mStopped);
- mReceiver.stop();
+ h.post(() -> mReceiver.stop());
waitForActivity();
assertEquals(2, mReceiver.numPacketsReceived());
assertTrue(Arrays.equals(two, mLastRecvBuf));
@@ -208,4 +210,32 @@ public class PacketReaderTest {
assertEquals(DEFAULT_RECV_BUF_SIZE, b.recvBufSize());
}
}
+
+ @Test
+ public void testStartingFromWrongThread() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ assertThrows(IllegalStateException.class, () -> b.start());
+ }
+
+ @Test
+ public void testStoppingFromWrongThread() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ assertThrows(IllegalStateException.class, () -> b.stop());
+ }
+
+ @Test
+ public void testSuccessToCreateSocket() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new UdpLoopbackReader(h);
+ h.post(() -> assertTrue(b.start()));
+ }
+
+ @Test
+ public void testFailToCreateSocket() throws Exception {
+ final Handler h = mHandlerThread.getThreadHandler();
+ final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE);
+ h.post(() -> assertFalse(b.start()));
+ }
}
diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
index e8bc8d2..49a2ce3 100644
--- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -173,6 +173,8 @@ public class NetworkMonitorTest {
private static final int VALIDATION_RESULT_VALID = NETWORK_VALIDATION_PROBE_DNS
| NETWORK_VALIDATION_PROBE_HTTPS
| NETWORK_VALIDATION_RESULT_VALID;
+ private static final int VALIDATION_RESULT_PRIVDNS_VALID = NETWORK_VALIDATION_PROBE_DNS
+ | NETWORK_VALIDATION_PROBE_HTTPS | NETWORK_VALIDATION_PROBE_PRIVDNS;
private static final int RETURN_CODE_DNS_SUCCESS = 0;
private static final int RETURN_CODE_DNS_TIMEOUT = 255;
@@ -800,6 +802,8 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
.notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
// Verify dns query only get v4 address.
resetCallbacks();
@@ -809,6 +813,10 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
.notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
eq(null));
+ // NetworkMonitor will check if the probes has changed or not, if the probes has not
+ // changed, the callback won't be fired.
+ verify(mCallbacks, never()).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
// Verify dns query get both v4 and v6 address.
resetCallbacks();
@@ -818,10 +826,12 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
.notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
eq(null));
+ verify(mCallbacks, never()).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
}
@Test
- public void testPrivateDnsResolutionRetryUpdate() throws Exception {
+ public void testProbeStatusChanged() throws Exception {
// Set no record in FakeDns and expect validation to fail.
setStatus(mHttpsConnection, 204);
setStatus(mHttpConnection, 204);
@@ -829,10 +839,40 @@ public class NetworkMonitorTest {
WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
- verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce())
- .notifyNetworkTested(
- eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS),
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyNetworkTested(
+ eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS
+ | NETWORK_VALIDATION_PROBE_HTTPS));
+ // Fix DNS and retry, expect validation to succeed.
+ resetCallbacks();
+ mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"}, TYPE_AAAA);
+
+ wnm.forceReevaluation(Process.myUid());
+ // ProbeCompleted should be reset to 0
+ HandlerUtilsKt.waitForIdle(wnm.getHandler(), HANDLER_TIMEOUT_MS);
+ assertEquals(wnm.getEvaluationState().getProbeCompletedResult(), 0);
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+ .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
+ }
+
+ @Test
+ public void testPrivateDnsResolutionRetryUpdate() throws Exception {
+ // Set no record in FakeDns and expect validation to fail.
+ setStatus(mHttpsConnection, 204);
+ setStatus(mHttpConnection, 204);
+
+ WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
+ wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+ wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS)).notifyNetworkTested(
+ eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS
+ | NETWORK_VALIDATION_PROBE_HTTPS));
// Fix DNS and retry, expect validation to succeed.
resetCallbacks();
@@ -842,6 +882,8 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce())
.notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
// Change configuration to an invalid DNS name, expect validation to fail.
resetCallbacks();
@@ -853,6 +895,9 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS))
.notifyNetworkTested(eq(NETWORK_VALIDATION_PROBE_DNS
| NETWORK_VALIDATION_PROBE_HTTPS), eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS
+ | NETWORK_VALIDATION_PROBE_HTTPS));
// Change configuration back to working again, but make private DNS not work.
// Expect validation to fail.
@@ -860,10 +905,13 @@ public class NetworkMonitorTest {
mFakeDns.setNonBypassPrivateDnsWorking(false);
wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google",
new InetAddress[0]));
- verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce())
- .notifyNetworkTested(
- eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS),
- eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS)).notifyNetworkTested(
+ eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null));
+ // NetworkMonitor will check if the probes has changed or not, if the probes has not
+ // changed, the callback won't be fired.
+ verify(mCallbacks, never()).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS
+ | NETWORK_VALIDATION_PROBE_HTTPS));
// Make private DNS work again. Expect validation to succeed.
resetCallbacks();
@@ -872,6 +920,8 @@ public class NetworkMonitorTest {
verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce())
.notifyNetworkTested(
eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null));
+ verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged(
+ eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID));
}
@Test