diff options
16 files changed, 769 insertions, 284 deletions
diff --git a/common/networkstackclient/Android.bp b/common/networkstackclient/Android.bp index ccb3f45..c82f751 100644 --- a/common/networkstackclient/Android.bp +++ b/common/networkstackclient/Android.bp @@ -87,6 +87,6 @@ java_library { ], static_libs: [ "ipmemorystore-aidl-interfaces-V3-java", - "networkstack-aidl-interfaces-V3-java", + "networkstack-aidl-interfaces-java", ], } diff --git a/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl b/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl index 2c61511..f8dcd6c 100644 --- a/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl +++ b/common/networkstackclient/src/android/net/INetworkMonitorCallbacks.aidl @@ -26,4 +26,5 @@ oneway interface INetworkMonitorCallbacks { void notifyPrivateDnsConfigResolved(in PrivateDnsConfigParcel config); void showProvisioningNotification(String action, String packageName); void hideProvisioningNotification(); + void notifyProbeStatusChanged(int probesCompleted, int probesSucceeded); }
\ No newline at end of file diff --git a/src/android/net/apf/ApfFilter.java b/src/android/net/apf/ApfFilter.java index 1d3421c..2f74ad6 100644 --- a/src/android/net/apf/ApfFilter.java +++ b/src/android/net/apf/ApfFilter.java @@ -57,7 +57,6 @@ import android.system.ErrnoException; import android.system.Os; import android.text.format.DateUtils; import android.util.Log; -import android.util.Pair; import android.util.SparseArray; import com.android.internal.annotations.GuardedBy; @@ -500,6 +499,44 @@ public class ApfFilter { } } + /** + * Class to keep track of a section in a packet. + */ + private static class PacketSection { + public enum Type { + MATCH, // A field that should be matched (e.g., the router IP address). + IGNORE, // An ignored field such as the checksum of the flow label. Not matched. + LIFETIME, // A lifetime. Not matched, and generally counts toward minimum RA lifetime. + } + + /** The type of section. */ + public final Type type; + /** Offset into the packet at which this section begins. */ + public final int start; + /** Length of this section. */ + public final int length; + /** If this is a lifetime, the ICMP option that the defined it. 0 for router lifetime. */ + public final int option; + /** If this is a lifetime, the lifetime value. */ + public final long lifetime; + + PacketSection(int start, int length, Type type, int option, long lifetime) { + this.start = start; + this.length = length; + this.type = type; + this.option = option; + this.lifetime = lifetime; + } + + public String toString() { + if (type == Type.LIFETIME) { + return String.format("%s: (%d, %d) %d %d", type, start, length, option, lifetime); + } else { + return String.format("%s: (%d, %d)", type, start, length); + } + } + } + // A class to hold information about an RA. @VisibleForTesting class Ra { @@ -534,10 +571,10 @@ public class ApfFilter { // Note: mPacket's position() cannot be assumed to be reset. private final ByteBuffer mPacket; - // List of binary ranges that include the whole packet except the lifetimes. - // Pairs consist of offset and length. - private final ArrayList<Pair<Integer, Integer>> mNonLifetimes = - new ArrayList<Pair<Integer, Integer>>(); + + // List of sections in the packet. + private final ArrayList<PacketSection> mPacketSections = new ArrayList<>(); + // Minimum lifetime in packet long mMinLifetime; // When the packet was last captured, in seconds since Unix Epoch @@ -649,27 +686,65 @@ public class ApfFilter { } /** - * Add a binary range of the packet that does not include a lifetime to mNonLifetimes. - * Assumes mPacket.position() is as far as we've parsed the packet. - * @param lastNonLifetimeStart offset within packet of where the last binary range of - * data not including a lifetime. - * @param lifetimeOffset offset from mPacket.position() to the next lifetime data. - * @param lifetimeLength length of the next lifetime data. - * @return offset within packet of where the next binary range of data not including - * a lifetime. This can be passed into the next invocation of this function - * via {@code lastNonLifetimeStart}. + * Add a packet section that should be matched, starting from the current position. + * @param length the length of the section + */ + private void addMatchSection(int length) { + // Don't generate JNEBS instruction for 0 bytes as they will fail the + // ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1) check (where cmp_imm is + // the number of bytes to compare) and immediately pass the packet. + // The code does not attempt to generate such matches, but add a safety + // check to prevent doing so in the presence of bugs or malformed or + // truncated packets. + if (length == 0) return; + mPacketSections.add( + new PacketSection(mPacket.position(), length, PacketSection.Type.MATCH, 0, 0)); + mPacket.position(mPacket.position() + length); + } + + /** + * Add a packet section that should be matched, starting from the current position. + * @param end the offset in the packet before which the section ends + */ + private void addMatchUntil(int end) { + addMatchSection(end - mPacket.position()); + } + + /** + * Add a packet section that should be ignored, starting from the current position. + * @param length the length of the section in bytes */ - private int addNonLifetime(int lastNonLifetimeStart, int lifetimeOffset, - int lifetimeLength) { - lifetimeOffset += mPacket.position(); - mNonLifetimes.add(new Pair<Integer, Integer>(lastNonLifetimeStart, - lifetimeOffset - lastNonLifetimeStart)); - return lifetimeOffset + lifetimeLength; + private void addIgnoreSection(int length) { + mPacketSections.add( + new PacketSection(mPacket.position(), length, PacketSection.Type.IGNORE, 0, 0)); + mPacket.position(mPacket.position() + length); } - private int addNonLifetimeU32(int lastNonLifetimeStart) { - return addNonLifetime(lastNonLifetimeStart, - ICMP6_4_BYTE_LIFETIME_OFFSET, ICMP6_4_BYTE_LIFETIME_LEN); + /** + * Add a packet section that represents a lifetime, starting from the current position. + * @param length the length of the section in bytes + * @param optionType the RA option containing this lifetime, or 0 for router lifetime + * @param lifetime the lifetime + */ + private void addLifetimeSection(int length, int optionType, long lifetime) { + mPacketSections.add( + new PacketSection(mPacket.position(), length, PacketSection.Type.LIFETIME, + optionType, lifetime)); + mPacket.position(mPacket.position() + length); + } + + /** + * Adds packet sections for an RA option with a 4-byte lifetime 4 bytes into the option + * @param optionType the RA option that is being added + * @param optionLength the length of the option in bytes + */ + private long add4ByteLifetimeOption(int optionType, int optionLength) { + addMatchSection(ICMP6_4_BYTE_LIFETIME_OFFSET); + final long lifetime = getUint32(mPacket, mPacket.position()); + addLifetimeSection(ICMP6_4_BYTE_LIFETIME_LEN, optionType, lifetime); + addMatchSection(optionLength - ICMP6_4_BYTE_LIFETIME_OFFSET + - ICMP6_4_BYTE_LIFETIME_LEN); + return lifetime; } // Note that this parses RA and may throw InvalidRaException (from @@ -696,20 +771,18 @@ public class ApfFilter { RaEvent.Builder builder = new RaEvent.Builder(); // Ignore the flow label and low 4 bits of traffic class. - int lastNonLifetimeStart = addNonLifetime(0, - IPV6_FLOW_LABEL_OFFSET, - IPV6_FLOW_LABEL_LEN); + addMatchUntil(IPV6_FLOW_LABEL_OFFSET); + addIgnoreSection(IPV6_FLOW_LABEL_LEN); - // Ignore the checksum. - lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart, - ICMP6_RA_CHECKSUM_OFFSET, - ICMP6_RA_CHECKSUM_LEN); + // Ignore checksum. + addMatchUntil(ICMP6_RA_CHECKSUM_OFFSET); + addIgnoreSection(ICMP6_RA_CHECKSUM_LEN); // Parse router lifetime - lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart, - ICMP6_RA_ROUTER_LIFETIME_OFFSET, - ICMP6_RA_ROUTER_LIFETIME_LEN); - builder.updateRouterLifetime(getUint16(mPacket, ICMP6_RA_ROUTER_LIFETIME_OFFSET)); + addMatchUntil(ICMP6_RA_ROUTER_LIFETIME_OFFSET); + final long routerLifetime = getUint16(mPacket, ICMP6_RA_ROUTER_LIFETIME_OFFSET); + addLifetimeSection(ICMP6_RA_ROUTER_LIFETIME_LEN, 0, routerLifetime); + builder.updateRouterLifetime(routerLifetime); // Ensures that the RA is not truncated. mPacket.position(ICMP6_RA_OPTION_OFFSET); @@ -720,64 +793,62 @@ public class ApfFilter { long lifetime; switch (optionType) { case ICMP6_PREFIX_OPTION_TYPE: + mPrefixOptionOffsets.add(position); + // Parse valid lifetime - lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart, - ICMP6_PREFIX_OPTION_VALID_LIFETIME_OFFSET, - ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN); - lifetime = getUint32(mPacket, - position + ICMP6_PREFIX_OPTION_VALID_LIFETIME_OFFSET); + addMatchSection(ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN); + lifetime = getUint32(mPacket, mPacket.position()); + addLifetimeSection(ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN, + ICMP6_PREFIX_OPTION_TYPE, lifetime); builder.updatePrefixValidLifetime(lifetime); + // Parse preferred lifetime - lastNonLifetimeStart = addNonLifetime(lastNonLifetimeStart, - ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_OFFSET, - ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_LEN); - lifetime = getUint32(mPacket, - position + ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_OFFSET); + lifetime = getUint32(mPacket, mPacket.position()); + addLifetimeSection(ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_LEN, + ICMP6_PREFIX_OPTION_TYPE, lifetime); builder.updatePrefixPreferredLifetime(lifetime); - mPrefixOptionOffsets.add(position); + + addMatchSection(4); // Reserved bytes + addMatchSection(IPV6_ADDR_LEN); // The prefix itself break; // These three options have the same lifetime offset and size, and - // are processed with the same specialized addNonLifetimeU32: + // are processed with the same specialized add4ByteLifetimeOption: case ICMP6_RDNSS_OPTION_TYPE: mRdnssOptionOffsets.add(position); - lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart); - lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET); + lifetime = add4ByteLifetimeOption(optionType, optionLength); builder.updateRdnssLifetime(lifetime); break; case ICMP6_ROUTE_INFO_OPTION_TYPE: mRioOptionOffsets.add(position); - lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart); - lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET); + lifetime = add4ByteLifetimeOption(optionType, optionLength); builder.updateRouteInfoLifetime(lifetime); break; case ICMP6_DNSSL_OPTION_TYPE: - lastNonLifetimeStart = addNonLifetimeU32(lastNonLifetimeStart); - lifetime = getUint32(mPacket, position + ICMP6_4_BYTE_LIFETIME_OFFSET); + lifetime = add4ByteLifetimeOption(optionType, optionLength); builder.updateDnsslLifetime(lifetime); break; default: - // RFC4861 section 4.2 dictates we ignore unknown options for fowards + // RFC4861 section 4.2 dictates we ignore unknown options for forwards // compatibility. + mPacket.position(position + optionLength); break; } if (optionLength <= 0) { throw new InvalidRaException(String.format( "Invalid option length opt=%d len=%d", optionType, optionLength)); } - mPacket.position(position + optionLength); } - // Mark non-lifetime bytes since last lifetime. - addNonLifetime(lastNonLifetimeStart, 0, 0); - mMinLifetime = minLifetime(packet, length); + mMinLifetime = minLifetime(); mMetricsLog.log(builder.build()); } - // Ignoring lifetimes (which may change) does {@code packet} match this RA? + // Considering only the MATCH sections, does {@code packet} match this RA? boolean matches(byte[] packet, int length) { if (length != mPacket.capacity()) return false; byte[] referencePacket = mPacket.array(); - for (Pair<Integer, Integer> nonLifetime : mNonLifetimes) { - for (int i = nonLifetime.first; i < (nonLifetime.first + nonLifetime.second); i++) { + for (PacketSection section : mPacketSections) { + if (section.type != PacketSection.Type.MATCH) continue; + for (int i = section.start; i < (section.start + section.length); i++) { if (packet[i] != referencePacket[i]) return false; } } @@ -786,36 +857,12 @@ public class ApfFilter { // What is the minimum of all lifetimes within {@code packet} in seconds? // Precondition: matches(packet, length) already returned true. - long minLifetime(byte[] packet, int length) { + long minLifetime() { long minLifetime = Long.MAX_VALUE; - // Wrap packet in ByteBuffer so we can read big-endian values easily - ByteBuffer byteBuffer = ByteBuffer.wrap(packet); - for (int i = 0; (i + 1) < mNonLifetimes.size(); i++) { - int offset = mNonLifetimes.get(i).first + mNonLifetimes.get(i).second; - - // The flow label is in mNonLifetimes, but it's not a lifetime. - if (offset == IPV6_FLOW_LABEL_OFFSET) { - continue; - } - - // The checksum is in mNonLifetimes, but it's not a lifetime. - if (offset == ICMP6_RA_CHECKSUM_OFFSET) { - continue; - } - - final int lifetimeLength = mNonLifetimes.get(i+1).first - offset; - final long optionLifetime; - switch (lifetimeLength) { - case 2: - optionLifetime = getUint16(byteBuffer, offset); - break; - case 4: - optionLifetime = getUint32(byteBuffer, offset); - break; - default: - throw new IllegalStateException("bogus lifetime size " + lifetimeLength); + for (PacketSection section : mPacketSections) { + if (section.type == PacketSection.Type.LIFETIME) { + minLifetime = Math.min(minLifetime, section.lifetime); } - minLifetime = Math.min(minLifetime, optionLifetime); } return minLifetime; } @@ -844,38 +891,24 @@ public class ApfFilter { // Skip filter if expired gen.addLoadFromMemory(Register.R0, gen.FILTER_AGE_MEMORY_SLOT); gen.addJumpIfR0GreaterThan(filterLifetime, nextFilterLabel); - for (int i = 0; i < mNonLifetimes.size(); i++) { - // Generate code to match the packet bytes - Pair<Integer, Integer> nonLifetime = mNonLifetimes.get(i); - // Don't generate JNEBS instruction for 0 bytes as it always fails the - // ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1) check where cmp_imm is - // the number of bytes to compare. nonLifetime is zero between the - // valid and preferred lifetimes in the prefix option. - if (nonLifetime.second != 0) { - gen.addLoadImmediate(Register.R0, nonLifetime.first); + for (PacketSection section : mPacketSections) { + // Generate code to match the packet bytes. + if (section.type == PacketSection.Type.MATCH) { + gen.addLoadImmediate(Register.R0, section.start); gen.addJumpIfBytesNotEqual(Register.R0, - Arrays.copyOfRange(mPacket.array(), nonLifetime.first, - nonLifetime.first + nonLifetime.second), + Arrays.copyOfRange(mPacket.array(), section.start, + section.start + section.length), nextFilterLabel); } - // Generate code to test the lifetimes haven't gone down too far - if ((i + 1) < mNonLifetimes.size()) { - Pair<Integer, Integer> nextNonLifetime = mNonLifetimes.get(i + 1); - int offset = nonLifetime.first + nonLifetime.second; - - // Skip the Flow label. - if (offset == IPV6_FLOW_LABEL_OFFSET) { - continue; - } - // Skip the checksum. - if (offset == ICMP6_RA_CHECKSUM_OFFSET) { - continue; - } - int length = nextNonLifetime.first - offset; - switch (length) { - case 4: gen.addLoad32(Register.R0, offset); break; - case 2: gen.addLoad16(Register.R0, offset); break; - default: throw new IllegalStateException("bogus lifetime size " + length); + // Generate code to test the lifetimes haven't gone down too far. + // The packet is accepted if any of its lifetimes are lower than filterLifetime. + if (section.type == PacketSection.Type.LIFETIME) { + switch (section.length) { + case 4: gen.addLoad32(Register.R0, section.start); break; + case 2: gen.addLoad16(Register.R0, section.start); break; + default: + throw new IllegalStateException( + "bogus lifetime size " + section.length); } gen.addJumpIfR0LessThan(filterLifetime, nextFilterLabel); } @@ -1682,7 +1715,7 @@ public class ApfFilter { if (VDBG) log("matched RA " + ra); // Update lifetimes. ra.mLastSeen = currentTimeSeconds(); - ra.mMinLifetime = ra.minLifetime(packet, length); + ra.mMinLifetime = ra.minLifetime(); ra.seenCount++; // Keep mRas in LRU order so as to prioritize generating filters for recently seen diff --git a/src/android/net/dhcp/DhcpClient.java b/src/android/net/dhcp/DhcpClient.java index bf63c1c..e88c7cc 100644 --- a/src/android/net/dhcp/DhcpClient.java +++ b/src/android/net/dhcp/DhcpClient.java @@ -37,6 +37,7 @@ import static android.system.OsConstants.AF_PACKET; import static android.system.OsConstants.ETH_P_IP; import static android.system.OsConstants.IPPROTO_UDP; import static android.system.OsConstants.SOCK_DGRAM; +import static android.system.OsConstants.SOCK_NONBLOCK; import static android.system.OsConstants.SOCK_RAW; import static android.system.OsConstants.SOL_SOCKET; import static android.system.OsConstants.SO_BROADCAST; @@ -46,6 +47,7 @@ import static android.system.OsConstants.SO_REUSEADDR; import static com.android.server.util.NetworkStackConstants.IPV4_ADDR_ANY; import android.annotation.NonNull; +import android.annotation.Nullable; import android.content.Context; import android.net.DhcpResults; import android.net.InetAddresses; @@ -60,7 +62,9 @@ import android.net.metrics.DhcpErrorEvent; import android.net.metrics.IpConnectivityLog; import android.net.util.InterfaceParams; import android.net.util.NetworkStackUtils; +import android.net.util.PacketReader; import android.net.util.SocketUtils; +import android.os.Handler; import android.os.Message; import android.os.SystemClock; import android.system.ErrnoException; @@ -210,14 +214,9 @@ public class DhcpClient extends StateMachine { private final Random mRandom; private final IpConnectivityLog mMetricsLog = new IpConnectivityLog(); - // Sockets. - // - We use a packet socket to receive, because servers send us packets bound for IP addresses - // which we have not yet configured, and the kernel protocol stack drops these. - // - We use a UDP socket to send, so the kernel handles ARP and routing for us (DHCP servers can - // be off-link as well as on-link). - private FileDescriptor mPacketSock; + // We use a UDP socket to send, so the kernel handles ARP and routing for us (DHCP servers can + // be off-link as well as on-link). private FileDescriptor mUdpSock; - private ReceiveThread mReceiveThread; // State variables. private final StateMachine mController; @@ -244,6 +243,8 @@ public class DhcpClient extends StateMachine { private Dependencies mDependencies; @NonNull private final NetworkStackIpMemoryStore mIpMemoryStore; + @Nullable + private DhcpPacketHandler mDhcpPacketHandler; // Milliseconds SystemClock timestamps used to record transition times to DhcpBoundState. private long mLastInitEnterTime; @@ -396,23 +397,6 @@ public class DhcpClient extends StateMachine { mTransactionStartMillis = SystemClock.elapsedRealtime(); } - private boolean initSockets() { - return initPacketSocket() && initUdpSocket(); - } - - private boolean initPacketSocket() { - try { - mPacketSock = Os.socket(AF_PACKET, SOCK_RAW, ETH_P_IP); - SocketAddress addr = makePacketSocketAddress(ETH_P_IP, mIface.index); - Os.bind(mPacketSock, addr); - NetworkStackUtils.attachDhcpFilter(mPacketSock); - } catch(SocketException|ErrnoException e) { - Log.e(TAG, "Error creating packet socket", e); - return false; - } - return true; - } - private boolean initUdpSocket() { final int oldTag = TrafficStats.getAndSetThreadStatsTag( TrafficStatsConstants.TAG_SYSTEM_DHCP); @@ -423,7 +407,7 @@ public class DhcpClient extends StateMachine { Os.setsockoptInt(mUdpSock, SOL_SOCKET, SO_BROADCAST, 1); Os.setsockoptInt(mUdpSock, SOL_SOCKET, SO_RCVBUF, 0); Os.bind(mUdpSock, IPV4_ADDR_ANY, DhcpPacket.DHCP_CLIENT); - } catch(SocketException|ErrnoException e) { + } catch (SocketException | ErrnoException e) { Log.e(TAG, "Error creating UDP socket", e); return false; } finally { @@ -436,59 +420,76 @@ public class DhcpClient extends StateMachine { try { Os.connect(mUdpSock, to, DhcpPacket.DHCP_SERVER); return true; - } catch (SocketException|ErrnoException e) { + } catch (SocketException | ErrnoException e) { Log.e(TAG, "Error connecting UDP socket", e); return false; } } - private void closeSockets() { - closeSocketQuietly(mUdpSock); - closeSocketQuietly(mPacketSock); - } + private class DhcpPacketHandler extends PacketReader { + private FileDescriptor mPacketSock; - class ReceiveThread extends Thread { + DhcpPacketHandler(Handler handler) { + super(handler); + } - private final byte[] mPacket = new byte[DhcpPacket.MAX_LENGTH]; - private volatile boolean mStopped = false; + @Override + protected void handlePacket(byte[] recvbuf, int length) { + try { + final DhcpPacket packet = DhcpPacket.decodeFullPacket(recvbuf, length, + DhcpPacket.ENCAP_L2); + if (DBG) Log.d(TAG, "Received packet: " + packet); + sendMessage(CMD_RECEIVED_PACKET, packet); + } catch (DhcpPacket.ParseException e) { + Log.e(TAG, "Can't parse packet: " + e.getMessage()); + if (PACKET_DBG) { + Log.d(TAG, HexDump.dumpHexString(recvbuf, 0, length)); + } + if (e.errorCode == DhcpErrorEvent.DHCP_NO_COOKIE) { + final int snetTagId = 0x534e4554; + final String bugId = "31850211"; + final int uid = -1; + final String data = DhcpPacket.ParseException.class.getName(); + EventLog.writeEvent(snetTagId, bugId, uid, data); + } + mMetricsLog.log(mIfaceName, new DhcpErrorEvent(e.errorCode)); + } + } - public void halt() { - mStopped = true; - closeSockets(); // Interrupts the read() call the thread is blocked in. + @Override + protected FileDescriptor createFd() { + try { + mPacketSock = Os.socket(AF_PACKET, SOCK_RAW | SOCK_NONBLOCK, 0 /* protocol */); + NetworkStackUtils.attachDhcpFilter(mPacketSock); + final SocketAddress addr = makePacketSocketAddress(ETH_P_IP, mIface.index); + Os.bind(mPacketSock, addr); + } catch (SocketException | ErrnoException e) { + logError("Error creating packet socket", e); + closeFd(mPacketSock); + mPacketSock = null; + return null; + } + return mPacketSock; } @Override - public void run() { - if (DBG) Log.d(TAG, "Receive thread started"); - while (!mStopped) { - int length = 0; // Or compiler can't tell it's initialized if a parse error occurs. - try { - length = Os.read(mPacketSock, mPacket, 0, mPacket.length); - DhcpPacket packet = null; - packet = DhcpPacket.decodeFullPacket(mPacket, length, DhcpPacket.ENCAP_L2); - if (DBG) Log.d(TAG, "Received packet: " + packet); - sendMessage(CMD_RECEIVED_PACKET, packet); - } catch (IOException|ErrnoException e) { - if (!mStopped) { - Log.e(TAG, "Read error", e); - logError(DhcpErrorEvent.RECEIVE_ERROR); - } - } catch (DhcpPacket.ParseException e) { - Log.e(TAG, "Can't parse packet: " + e.getMessage()); - if (PACKET_DBG) { - Log.d(TAG, HexDump.dumpHexString(mPacket, 0, length)); - } - if (e.errorCode == DhcpErrorEvent.DHCP_NO_COOKIE) { - int snetTagId = 0x534e4554; - String bugId = "31850211"; - int uid = -1; - String data = DhcpPacket.ParseException.class.getName(); - EventLog.writeEvent(snetTagId, bugId, uid, data); - } - logError(e.errorCode); - } + protected int readPacket(FileDescriptor fd, byte[] packetBuffer) throws Exception { + try { + return Os.read(fd, packetBuffer, 0, packetBuffer.length); + } catch (IOException | ErrnoException e) { + mMetricsLog.log(mIfaceName, new DhcpErrorEvent(DhcpErrorEvent.RECEIVE_ERROR)); + throw e; } - if (DBG) Log.d(TAG, "Receive thread stopped"); + } + + @Override + protected void logError(@NonNull String msg, @Nullable Exception e) { + Log.e(TAG, msg, e); + } + + public int transmitPacket(final ByteBuffer buf, final SocketAddress socketAddress) + throws ErrnoException, SocketException { + return Os.sendto(mPacketSock, buf.array(), 0, buf.limit(), 0, socketAddress); } } @@ -500,7 +501,7 @@ public class DhcpClient extends StateMachine { try { if (encap == DhcpPacket.ENCAP_L2) { if (DBG) Log.d(TAG, "Broadcasting " + description); - Os.sendto(mPacketSock, buf.array(), 0, buf.limit(), 0, mInterfaceBroadcastAddr); + mDhcpPacketHandler.transmitPacket(buf, mInterfaceBroadcastAddr); } else if (encap == DhcpPacket.ENCAP_BOOTP && to.equals(INADDR_BROADCAST)) { if (DBG) Log.d(TAG, "Broadcasting " + description); // We only send L3-encapped broadcasts in DhcpRebindingState, @@ -517,7 +518,7 @@ public class DhcpClient extends StateMachine { description, Os.getpeername(mUdpSock))); Os.write(mUdpSock, buf); } - } catch(ErrnoException|IOException e) { + } catch (ErrnoException | IOException e) { Log.e(TAG, "Can't send packet: ", e); return false; } @@ -774,21 +775,22 @@ public class DhcpClient extends StateMachine { @Override public void enter() { clearDhcpState(); - if (initInterface() && initSockets()) { - mReceiveThread = new ReceiveThread(); - mReceiveThread.start(); - } else { - notifyFailure(); - transitionTo(mStoppedState); + if (initInterface() && initUdpSocket()) { + mDhcpPacketHandler = new DhcpPacketHandler(getHandler()); + if (mDhcpPacketHandler.start()) return; + Log.e(TAG, "Fail to start DHCP Packet Handler"); } + notifyFailure(); + transitionTo(mStoppedState); } @Override public void exit() { - if (mReceiveThread != null) { - mReceiveThread.halt(); // Also closes sockets. - mReceiveThread = null; + if (mDhcpPacketHandler != null) { + mDhcpPacketHandler.stop(); + if (DBG) Log.d(TAG, "DHCP Packet Handler stopped"); } + closeSocketQuietly(mUdpSock); clearDhcpState(); } @@ -1289,10 +1291,6 @@ public class DhcpClient extends StateMachine { class DhcpRebootingState extends LoggingState { } - private void logError(int errorCode) { - mMetricsLog.log(mIfaceName, new DhcpErrorEvent(errorCode)); - } - private void logState(String name, int durationMs) { final DhcpClientEvent event = new DhcpClientEvent.Builder() .setMsg(name) diff --git a/src/android/net/util/FdEventsReader.java b/src/android/net/util/FdEventsReader.java index 1380ea7..e82c69b 100644 --- a/src/android/net/util/FdEventsReader.java +++ b/src/android/net/util/FdEventsReader.java @@ -63,7 +63,6 @@ import java.io.IOException; * the Handler constructor argument is associated. * * @param <BufferType> the type of the buffer used to read data. - * @hide */ public abstract class FdEventsReader<BufferType> { private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR; @@ -93,27 +92,21 @@ public abstract class FdEventsReader<BufferType> { } /** Start this FdEventsReader. */ - public void start() { - if (onCorrectThread()) { - createAndRegisterFd(); - } else { - mHandler.post(() -> { - logError("start() called from off-thread", null); - createAndRegisterFd(); - }); + public boolean start() { + if (!onCorrectThread()) { + throw new IllegalStateException("start() called from off-thread"); } + + return createAndRegisterFd(); } /** Stop this FdEventsReader and destroy the file descriptor. */ public void stop() { - if (onCorrectThread()) { - unregisterAndDestroyFd(); - } else { - mHandler.post(() -> { - logError("stop() called from off-thread", null); - unregisterAndDestroyFd(); - }); + if (!onCorrectThread()) { + throw new IllegalStateException("stop() called from off-thread"); } + + unregisterAndDestroyFd(); } @NonNull @@ -178,8 +171,8 @@ public abstract class FdEventsReader<BufferType> { */ protected void onStop() {} - private void createAndRegisterFd() { - if (mFd != null) return; + private boolean createAndRegisterFd() { + if (mFd != null) return true; try { mFd = createFd(); @@ -189,7 +182,7 @@ public abstract class FdEventsReader<BufferType> { mFd = null; } - if (mFd == null) return; + if (mFd == null) return false; mQueue.addOnFileDescriptorEventListener( mFd, @@ -205,6 +198,7 @@ public abstract class FdEventsReader<BufferType> { return FD_EVENTS; }); onStart(); + return true; } private boolean isRunning() { diff --git a/src/android/net/util/PacketReader.java b/src/android/net/util/PacketReader.java index 4aec6b6..0be7187 100644 --- a/src/android/net/util/PacketReader.java +++ b/src/android/net/util/PacketReader.java @@ -28,8 +28,6 @@ import java.io.FileDescriptor; * * TODO: rename this class to something more correctly descriptive (something * like [or less horrible than] FdReadEventsHandler?). - * - * @hide */ public abstract class PacketReader extends FdEventsReader<byte[]> { diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java index 585e38e..e08170a 100644 --- a/src/com/android/server/connectivity/NetworkMonitor.java +++ b/src/com/android/server/connectivity/NetworkMonitor.java @@ -368,7 +368,13 @@ public class NetworkMonitor extends StateMachine { } catch (RemoteException e) { version = 0; } - if (version == Build.VERSION_CODES.CUR_DEVELOPMENT) version = 0; + // The AIDL was freezed from Q beta 5 but it's unfreezing from R before releasing. In order + // to distinguish the behavior between R and Q beta 5 and before Q beta 5, add SDK and + // CODENAME check here. Basically, it's only expected to return 0 for Q beta 4 and below + // because the test result has changed. + if (Build.VERSION.SDK_INT == Build.VERSION_CODES.Q + && Build.VERSION.CODENAME.equals("REL") + && version == Build.VERSION_CODES.CUR_DEVELOPMENT) version = 0; return version; } @@ -555,6 +561,14 @@ public class NetworkMonitor extends StateMachine { } } + private void notifyProbeStatusChanged(int probesCompleted, int probesSucceeded) { + try { + mCallback.notifyProbeStatusChanged(probesCompleted, probesSucceeded); + } catch (RemoteException e) { + Log.e(TAG, "Error sending probe status", e); + } + } + private void showProvisioningNotification(String action) { try { mCallback.showProvisioningNotification(action, mContext.getPackageName()); @@ -667,6 +681,8 @@ public class NetworkMonitor extends StateMachine { // no resolved IP addresses, IPs unreachable, // port 853 unreachable, port 853 is not running a // DNS-over-TLS server, et cetera). + // Cancel any outstanding CMD_EVALUATE_PRIVATE_DNS. + removeMessages(CMD_EVALUATE_PRIVATE_DNS); sendMessage(CMD_EVALUATE_PRIVATE_DNS); break; } @@ -1020,11 +1036,19 @@ public class NetworkMonitor extends StateMachine { handlePrivateDnsEvaluationFailure(); break; } + handlePrivateDnsEvaluationSuccess(); + } else { + mEvaluationState.removeProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS); } // All good! transitionTo(mValidatedState); break; + case CMD_PRIVATE_DNS_SETTINGS_CHANGED: + // When settings change the reevaluation timer must be reset. + mPrivateDnsReevalDelayMs = INITIAL_REEVALUATE_DELAY_MS; + // Let the message bubble up and be handled by parent states as usual. + return NOT_HANDLED; default: return NOT_HANDLED; } @@ -1051,8 +1075,6 @@ public class NetworkMonitor extends StateMachine { } catch (UnknownHostException uhe) { mPrivateDnsConfig = null; } - mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, - (mPrivateDnsConfig != null) /* succeeded */); } private void notifyPrivateDnsConfigResolved() { @@ -1063,7 +1085,14 @@ public class NetworkMonitor extends StateMachine { } } + private void handlePrivateDnsEvaluationSuccess() { + mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, + true /* succeeded */); + } + private void handlePrivateDnsEvaluationFailure() { + mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, + false /* succeeded */); mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_INVALID, null /* redirectUrl */); // Queue up a re-evaluation with backoff. @@ -1072,10 +1101,6 @@ public class NetworkMonitor extends StateMachine { // transitioning back to EvaluatingState, to perhaps give ourselves // the opportunity to (re)detect a captive portal or something. // - // TODO: distinguish between CMD_EVALUATE_PRIVATE_DNS messages that are caused by server - // lookup failures (which should continue to do exponential backoff) and - // CMD_EVALUATE_PRIVATE_DNS messages that are caused by user reconfiguration (which - // should be processed immediately. sendMessageDelayed(CMD_EVALUATE_PRIVATE_DNS, mPrivateDnsReevalDelayMs); mPrivateDnsReevalDelayMs *= 2; if (mPrivateDnsReevalDelayMs > MAX_REEVALUATE_DELAY_MS) { @@ -1101,7 +1126,6 @@ public class NetworkMonitor extends StateMachine { String.format("%dms - Error: %s", time, uhe.getMessage())); } logValidationProbe(time, PROBE_PRIVDNS, success ? DNS_SUCCESS : DNS_FAILURE); - mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS, success); return success; } } @@ -2105,22 +2129,43 @@ public class NetworkMonitor extends StateMachine { // Indicates which probes have completed since clearProbeResults was called. // This is a bitmask of INetworkMonitor.NETWORK_VALIDATION_PROBE_* constants. private int mProbeResults = 0; + // A bitmask to record which probes are completed. + private int mProbeCompleted = 0; // The latest redirect URL. private String mRedirectUrl; protected void clearProbeResults() { mProbeResults = 0; + mProbeCompleted = 0; } - // Probe result for http probe should be updated from reportHttpProbeResult(). - protected void noteProbeResult(int probeResult, boolean succeeded) { - if (succeeded) { - mProbeResults |= probeResult; - } else { - mProbeResults &= ~probeResult; + private void maybeNotifyProbeResults(@NonNull final Runnable modif) { + final int oldCompleted = mProbeCompleted; + final int oldResults = mProbeResults; + modif.run(); + if (oldCompleted != mProbeCompleted || oldResults != mProbeResults) { + notifyProbeStatusChanged(mProbeCompleted, mProbeResults); } } + protected void removeProbeResult(final int probeResult) { + maybeNotifyProbeResults(() -> { + mProbeCompleted &= ~probeResult; + mProbeResults &= ~probeResult; + }); + } + + protected void noteProbeResult(final int probeResult, final boolean succeeded) { + maybeNotifyProbeResults(() -> { + mProbeCompleted |= probeResult; + if (succeeded) { + mProbeResults |= probeResult; + } else { + mProbeResults &= ~probeResult; + } + }); + } + protected void reportEvaluationResult(int result, @Nullable String redirectUrl) { mEvaluationResult = result; mRedirectUrl = redirectUrl; @@ -2139,6 +2184,11 @@ public class NetworkMonitor extends StateMachine { } return mEvaluationResult | mProbeResults; } + + @VisibleForTesting + protected int getProbeCompletedResult() { + return mProbeCompleted; + } } @VisibleForTesting diff --git a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java index cb7d418..f32171c 100644 --- a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java +++ b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java @@ -130,6 +130,7 @@ public class IpClientIntegrationTest { private String mIfaceName; private INetd mNetd; private HandlerThread mPacketReaderThread; + private Handler mHandler; private TapPacketReader mPacketReader; private IpClient mIpc; private Dependencies mDependencies; @@ -271,7 +272,7 @@ public class IpClientIntegrationTest { @After public void tearDown() throws Exception { if (mPacketReader != null) { - mPacketReader.stop(); // Also closes the socket + mHandler.post(() -> mPacketReader.stop()); // Also closes the socket } if (mPacketReaderThread != null) { mPacketReaderThread.quitSafely(); @@ -297,10 +298,11 @@ public class IpClientIntegrationTest { mIfaceName = iface.getInterfaceName(); mPacketReaderThread = new HandlerThread(IpClientIntegrationTest.class.getSimpleName()); mPacketReaderThread.start(); + mHandler = mPacketReaderThread.getThreadHandler(); final ParcelFileDescriptor tapFd = iface.getFileDescriptor(); - mPacketReader = new TapPacketReader(mPacketReaderThread.getThreadHandler(), tapFd); - mPacketReader.start(); + mPacketReader = new TapPacketReader(mHandler, tapFd); + mHandler.post(() -> mPacketReader.start()); } private void setUpIpClient() throws Exception { diff --git a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt index 7796dcd..25b1e0f 100644 --- a/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt +++ b/tests/lib/src/com/android/testutils/ConcurrentIntepreter.kt @@ -4,6 +4,7 @@ import android.os.SystemClock import java.util.concurrent.CyclicBarrier import kotlin.system.measureTimeMillis import kotlin.test.assertEquals +import kotlin.test.assertFails import kotlin.test.assertNull import kotlin.test.assertTrue @@ -64,7 +65,15 @@ open class ConcurrentIntepreter<T>( // Spins as many threads as needed by the test spec and interpret each program concurrently, // having all threads waiting on a CyclicBarrier after each line. - fun interpretTestSpec(spec: String, initial: T, threadTransform: (T) -> T = { it }) { + // |lineShift| says how many lines after the call the spec starts. This is used for error + // reporting. Unfortunately AFAICT there is no way to get the line of an argument rather + // than the line at which the expression starts. + fun interpretTestSpec( + spec: String, + initial: T, + lineShift: Int = 0, + threadTransform: (T) -> T = { it } + ) { // For nice stack traces val callSite = getCallingMethod() val lines = spec.trim().trim('\n').split("\n").map { it.split("|") } @@ -91,7 +100,8 @@ open class ConcurrentIntepreter<T>( // testing. Instead, catch the exception, cancel other threads, and report // nicely. Catch throwable because fail() is AssertionError, which inherits // from Error. - crash = InterpretException(threadIndex, it, callSite.lineNumber + lineNum, + crash = InterpretException(threadIndex, it, + callSite.lineNumber + lineNum + lineShift, callSite.className, callSite.methodName, callSite.fileName, e) } barrier.await() @@ -147,6 +157,9 @@ private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>( // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units. Regex("""sleep(\((\d+)\))?""") to { i, t, r -> SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2)) + }, + Regex("""(.*)\s*fails""") to { i, t, r -> + assertFails { i.interpret(r.strArg(1), t) } } ) diff --git a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt index 1cc1168..bbb279e 100644 --- a/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt +++ b/tests/lib/src/com/android/testutils/TestableNetworkCallback.kt @@ -21,11 +21,15 @@ import android.net.LinkProperties import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED -import com.android.testutils.RecorderCallback.CallbackRecord.Available -import com.android.testutils.RecorderCallback.CallbackRecord.BlockedStatus -import com.android.testutils.RecorderCallback.CallbackRecord.CapabilitiesChanged -import com.android.testutils.RecorderCallback.CallbackRecord.LinkPropertiesChanged -import com.android.testutils.RecorderCallback.CallbackRecord.Lost +import com.android.testutils.RecorderCallback.CallbackEntry.Available +import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus +import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged +import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged +import com.android.testutils.RecorderCallback.CallbackEntry.Losing +import com.android.testutils.RecorderCallback.CallbackEntry.Lost +import com.android.testutils.RecorderCallback.CallbackEntry.Resumed +import com.android.testutils.RecorderCallback.CallbackEntry.Suspended +import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable import kotlin.reflect.KClass import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -35,38 +39,43 @@ object NULL_NETWORK : Network(-1) private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this) -open class RecorderCallback : NetworkCallback() { - sealed class CallbackRecord { +open class RecorderCallback private constructor( + private val backingRecord: ArrayTrackRecord<CallbackEntry> +) : NetworkCallback() { + public constructor() : this(ArrayTrackRecord()) + protected constructor(src: RecorderCallback?): this(src?.backingRecord ?: ArrayTrackRecord()) + + sealed class CallbackEntry { // To get equals(), hashcode(), componentN() etc for free, the child classes of // this class are data classes. But while data classes can inherit from other classes, // they may only have visible members in the constructors, so they couldn't declare - // a constructor with a non-val arg to pass to CallbackRecord. Instead, force all + // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all // subclasses to implement a `network' property, which can be done in a data class // constructor by specifying override. abstract val network: Network - data class Available(override val network: Network) : CallbackRecord() + data class Available(override val network: Network) : CallbackEntry() data class CapabilitiesChanged( override val network: Network, val caps: NetworkCapabilities - ) : CallbackRecord() + ) : CallbackEntry() data class LinkPropertiesChanged( override val network: Network, val lp: LinkProperties - ) : CallbackRecord() - data class Suspended(override val network: Network) : CallbackRecord() - data class Resumed(override val network: Network) : CallbackRecord() - data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackRecord() - data class Lost(override val network: Network) : CallbackRecord() + ) : CallbackEntry() + data class Suspended(override val network: Network) : CallbackEntry() + data class Resumed(override val network: Network) : CallbackEntry() + data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry() + data class Lost(override val network: Network) : CallbackEntry() data class Unavailable private constructor( override val network: Network - ) : CallbackRecord() { + ) : CallbackEntry() { constructor() : this(NULL_NETWORK) } data class BlockedStatus( override val network: Network, val blocked: Boolean - ) : CallbackRecord() + ) : CallbackEntry() // Convenience constants for expecting a type companion object { @@ -91,12 +100,15 @@ open class RecorderCallback : NetworkCallback() { } } - protected val history = ArrayTrackRecord<CallbackRecord>().newReadHead() + protected val history = backingRecord.newReadHead() override fun onAvailable(network: Network) { history.add(Available(network)) } + // PreCheck is not used in the tests today. For backward compatibility with existing tests that + // expect the callbacks not to record this, do not listen to PreCheck here. + override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) { history.add(CapabilitiesChanged(network, caps)) } @@ -110,39 +122,46 @@ open class RecorderCallback : NetworkCallback() { } override fun onNetworkSuspended(network: Network) { - history.add(CallbackRecord.Suspended(network)) + history.add(Suspended(network)) } override fun onNetworkResumed(network: Network) { - history.add(CallbackRecord.Resumed(network)) + history.add(Resumed(network)) } override fun onLosing(network: Network, maxMsToLive: Int) { - history.add(CallbackRecord.Losing(network, maxMsToLive)) + history.add(Losing(network, maxMsToLive)) } override fun onLost(network: Network) { - history.add(CallbackRecord.Lost(network)) + history.add(Lost(network)) } override fun onUnavailable() { - history.add(CallbackRecord.Unavailable()) + history.add(Unavailable()) } } -typealias CallbackType = KClass<out RecorderCallback.CallbackRecord> -const val DEFAULT_TIMEOUT = 200L // ms +private const val DEFAULT_TIMEOUT = 200L // ms + +open class TestableNetworkCallback private constructor( + src: TestableNetworkCallback?, + val defaultTimeoutMs: Long = DEFAULT_TIMEOUT +) : RecorderCallback(src) { + @JvmOverloads + constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs) + + fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs) -open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT) - : RecorderCallback() { - // The last available network. Null if the last available network was lost since. + // The last available network, or null if any network was lost since the last call to + // onAvailable. TODO : fix this by fixing the tests that rely on this behavior val lastAvailableNetwork: Network? get() = when (val it = history.lastOrNull { it is Available || it is Lost }) { is Available -> it.network else -> null } - fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackRecord { + fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry { return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms") } @@ -153,7 +172,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT) if (null != cb) fail("Expected no callback but got $cb") } - inline fun <reified T : CallbackRecord> expectCallback( + inline fun <reified T : CallbackEntry> expectCallback( network: Network, timeoutMs: Long = defaultTimeoutMs ): T = pollForNextCallback(timeoutMs).let { @@ -166,7 +185,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT) fun expectCallbackThat( timeoutMs: Long = defaultTimeoutMs, - valid: (CallbackRecord) -> Boolean + valid: (CallbackEntry) -> Boolean ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") } fun expectCapabilitiesThat( @@ -209,7 +228,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT) ) { expectCallback<Available>(net, tmt) if (suspended) { - expectCallback<CallbackRecord.Suspended>(net, tmt) + expectCallback<CallbackEntry.Suspended>(net, tmt) } expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) } expectCallback<LinkPropertiesChanged>(net, tmt) @@ -257,7 +276,7 @@ open class TestableNetworkCallback(val defaultTimeoutMs: Long = DEFAULT_TIMEOUT) } @JvmOverloads - open fun <T : CallbackRecord> expectCallback( + open fun <T : CallbackEntry> expectCallback( type: KClass<T>, n: HasNetwork?, timeoutMs: Long = defaultTimeoutMs diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp index 3410d0e..03bcf95 100644 --- a/tests/unit/Android.bp +++ b/tests/unit/Android.bp @@ -22,6 +22,7 @@ android_test { resource_dirs: ["res"], static_libs: [ "androidx.test.rules", + "kotlin-reflect", "mockito-target-extended-minus-junit4", "net-tests-utils", "NetworkStackApiCurrentLib", diff --git a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java index 64b168a..428baac 100644 --- a/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java +++ b/tests/unit/src/android/net/ip/IpReachabilityMonitorTest.java @@ -62,6 +62,8 @@ public class IpReachabilityMonitorTest { @Test public void testNothing() { - IpReachabilityMonitor monitor = makeMonitor(); + // make sure the unit test runs in the same thread with main looper. + // Otherwise, throwing IllegalStateException would cause test fails. + mHandler.post(() -> makeMonitor()); } } diff --git a/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt new file mode 100644 index 0000000..4e4d25a --- /dev/null +++ b/tests/unit/src/android/net/testutils/TestableNetworkCallbackTest.kt @@ -0,0 +1,293 @@ +package android.net.testutils + +import android.net.LinkAddress +import android.net.LinkProperties +import android.net.Network +import android.net.NetworkCapabilities +import com.android.testutils.ConcurrentIntepreter +import com.android.testutils.InterpretMatcher +import com.android.testutils.RecorderCallback.CallbackEntry +import com.android.testutils.RecorderCallback.CallbackEntry.Available +import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus +import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged +import com.android.testutils.TestableNetworkCallback +import com.android.testutils.intArg +import com.android.testutils.strArg +import com.android.testutils.timeArg +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import kotlin.reflect.KClass +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail + +const val SHORT_TIMEOUT_MS = 20L +const val DEFAULT_LINGER_DELAY_MS = 30000 +const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED +const val WIFI = NetworkCapabilities.TRANSPORT_WIFI +const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR +const val TEST_INTERFACE_NAME = "testInterfaceName" + +@RunWith(JUnit4::class) +class TestableNetworkCallbackTest { + private lateinit var mCallback: TestableNetworkCallback + + private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork { + override val network: Network = Network(netId) + } + + @Before + fun setUp() { + mCallback = TestableNetworkCallback() + } + + @Test + fun testLastAvailableNetwork() { + // Make sure there is no last available network at first, then the last available network + // is returned after onAvailable is called. + val net2097 = Network(2097) + assertNull(mCallback.lastAvailableNetwork) + mCallback.onAvailable(net2097) + assertEquals(mCallback.lastAvailableNetwork, net2097) + + // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available + // network. + mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities()) + mCallback.onLinkPropertiesChanged(net2097, LinkProperties()) + assertEquals(mCallback.lastAvailableNetwork, net2097) + + // Make sure onLost clears the last available network. + mCallback.onLost(net2097) + assertNull(mCallback.lastAvailableNetwork) + + // Do the same but with a different network after onLost : make sure the last available + // network is the new one, not the original one. + val net2098 = Network(2098) + mCallback.onAvailable(net2098) + mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities()) + mCallback.onLinkPropertiesChanged(net2098, LinkProperties()) + assertEquals(mCallback.lastAvailableNetwork, net2098) + + // Make sure onAvailable changes the last available network even if onLost was not called. + val net2099 = Network(2099) + mCallback.onAvailable(net2099) + assertEquals(mCallback.lastAvailableNetwork, net2099) + + // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily + // the last available one. Check that behavior. + mCallback.onLost(net2098) + assertNull(mCallback.lastAvailableNetwork) + + // Make sure that losing the really last available one still results in null. + mCallback.onLost(net2099) + assertNull(mCallback.lastAvailableNetwork) + + // Make sure multiple onAvailable in a row then onLost still results in null. + mCallback.onAvailable(net2097) + mCallback.onAvailable(net2098) + mCallback.onAvailable(net2099) + mCallback.onLost(net2097) + assertNull(mCallback.lastAvailableNetwork) + } + + @Test + fun testAssertNoCallback() { + mCallback.assertNoCallback(SHORT_TIMEOUT_MS) + mCallback.onAvailable(Network(100)) + assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) } + } + + @Test + fun testCapabilitiesWithAndWithout() { + val net = Network(101) + val matcher = makeHasNetwork(101) + val meteredNc = NetworkCapabilities() + val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED) + // Check that expecting caps (with or without) fails when no callback has been received. + assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + + // Add NOT_METERED and check that With succeeds and Without fails. + mCallback.onCapabilitiesChanged(net, unmeteredNc) + mCallback.expectCapabilitiesWith(NOT_METERED, matcher) + mCallback.onCapabilitiesChanged(net, unmeteredNc) + assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + + // Don't add NOT_METERED and check that With fails and Without succeeds. + mCallback.onCapabilitiesChanged(net, meteredNc) + assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + mCallback.onCapabilitiesChanged(net, meteredNc) + mCallback.expectCapabilitiesWithout(NOT_METERED, matcher) + } + + @Test + fun testExpectCallbackThat() { + val net = Network(193) + val netCaps = NetworkCapabilities().addTransportType(CELLULAR) + // Check that expecting callbackThat anything fails when no callback has been received. + assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onAvailable(net) + mCallback.expectCallbackThat { true } + mCallback.onAvailable(net) + assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { false } } + + // Try a positive and a negative case + mCallback.onBlockedStatusChanged(net, true) + mCallback.expectCallbackThat { cb -> cb is BlockedStatus && cb.blocked } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { cb -> + cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI) + } } + } + + @Test + fun testCapabilitiesThat() { + val net = Network(101) + val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI) + // Check that expecting capabilitiesThat anything fails when no callback has been received. + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onCapabilitiesChanged(net, netCaps) + mCallback.expectCapabilitiesThat(net) { true } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } } + + // Try a positive and a negative case + mCallback.onCapabilitiesChanged(net, netCaps) + mCallback.expectCapabilitiesThat(net) { caps -> + caps.hasCapability(NOT_METERED) && + caps.hasTransport(WIFI) && + !caps.hasTransport(CELLULAR) + } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps -> + caps.hasTransport(CELLULAR) + } } + + // Try a matching callback on the wrong network + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } } + } + + @Test + fun testLinkPropertiesThat() { + val net = Network(112) + val linkAddress = LinkAddress("fe80::ace:d00d/64") + val mtu = 1984 + val linkProps = LinkProperties().apply { + this.mtu = mtu + interfaceName = TEST_INTERFACE_NAME + addLinkAddress(linkAddress) + } + + // Check that expecting linkPropsThat anything fails when no callback has been received. + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onLinkPropertiesChanged(net, linkProps) + mCallback.expectLinkPropertiesThat(net) { true } + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } } + + // Try a positive and negative case + mCallback.onLinkPropertiesChanged(net, linkProps) + mCallback.expectLinkPropertiesThat(net) { lp -> + lp.interfaceName == TEST_INTERFACE_NAME && + lp.linkAddresses.contains(linkAddress) && + lp.mtu == mtu + } + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp -> + lp.interfaceName != TEST_INTERFACE_NAME + } } + + // Try a matching callback on the wrong network + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp -> + lp.interfaceName == TEST_INTERFACE_NAME + } } + } + + @Test + fun testExpectCallback() { + val net = Network(103) + // Test expectCallback fails when nothing was sent. + assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) } + + // Test onAvailable is seen and can be expected + mCallback.onAvailable(net) + mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) + + // Test onAvailable won't return calls with a different network + mCallback.onAvailable(Network(106)) + assertFails { mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) } + + // Test onAvailable won't return calls with a different callback + mCallback.onAvailable(net) + assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) } + } + + @Test + fun testPollForNextCallback() { + assertFails { mCallback.pollForNextCallback(SHORT_TIMEOUT_MS) } + TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1, + threadTransform = { cb -> cb.createLinkedCopy() }, spec = """ + sleep; onAvailable(133) | poll(2) = Available(133) time 1..4 + | poll(1) fails + onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3 + onBlockedStatus(199) | poll(1) = BlockedStatus(199) time 0..3 + """) + } +} + +private object TNCInterpreter : ConcurrentIntepreter<TestableNetworkCallback>(interpretTable) + +val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|") +private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> { + return CallbackEntry::class.sealedSubclasses.first { it.simpleName == name } +} + +private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>( + // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for + // all callback types. This is implemented above by enumerating the subclasses of + // CallbackEntry and reading their simpleName. + Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t -> + val record = i.interpret(t.strArg(1), cb) + assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record)) + // Strictly speaking testing for is CallbackEntry is useless as it's been tested above + // but the compiler can't figure things out from the isInstance call. It does understand + // from the assertTrue(is CallbackEntry) that this is true, which allows to access + // the 'network' member below. + assertTrue(record is CallbackEntry) + assertEquals(record.network.netId, t.intArg(3)) + }, + // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for + // all callback types. NetworkCapabilities and LinkProperties just get an empty object + // as their argument. Losing gets the default linger timer. Blocked gets false. + Regex("""on($EntryList)\((\d+)\)""") to { i, cb, t -> + val net = Network(t.intArg(2)) + when (t.strArg(1)) { + "Available" -> cb.onAvailable(net) + // PreCheck not used in tests. Add it here if it becomes useful. + "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities()) + "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties()) + "Suspended" -> cb.onNetworkSuspended(net) + "Resumed" -> cb.onNetworkResumed(net) + "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS) + "Lost" -> cb.onLost(net) + "Unavailable" -> cb.onUnavailable() + "BlockedStatus" -> cb.onBlockedStatusChanged(net, false) + else -> fail("Unknown callback type") + } + }, + Regex("""poll\((\d+)\)""") to { i, cb, t -> + cb.pollForNextCallback(t.timeArg(1)) + } +) diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt index 4fe8d37..995d537 100644 --- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt +++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt @@ -352,7 +352,8 @@ class TrackRecordTest { private object TRTInterpreter : ConcurrentIntepreter<TrackRecord<Int>>(interpretTable) { fun interpretTestSpec(spec: String, useReadHeads: Boolean) = if (useReadHeads) { - interpretTestSpec(spec, ArrayTrackRecord(), { (it as ArrayTrackRecord).newReadHead() }) + interpretTestSpec(spec, initial = ArrayTrackRecord(), + threadTransform = { (it as ArrayTrackRecord).newReadHead() }) } else { interpretTestSpec(spec, ArrayTrackRecord()) } diff --git a/tests/unit/src/android/net/util/PacketReaderTest.java b/tests/unit/src/android/net/util/PacketReaderTest.java index 289dcad..3947d15 100644 --- a/tests/unit/src/android/net/util/PacketReaderTest.java +++ b/tests/unit/src/android/net/util/PacketReaderTest.java @@ -24,6 +24,8 @@ import static android.system.OsConstants.SOCK_NONBLOCK; import static android.system.OsConstants.SOL_SOCKET; import static android.system.OsConstants.SO_SNDTIMEO; +import static com.android.testutils.MiscAssertsKt.assertThrows; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -182,7 +184,7 @@ public class PacketReaderTest { assertTrue(Arrays.equals(two, mLastRecvBuf)); assertFalse(mStopped); - mReceiver.stop(); + h.post(() -> mReceiver.stop()); waitForActivity(); assertEquals(2, mReceiver.numPacketsReceived()); assertTrue(Arrays.equals(two, mLastRecvBuf)); @@ -208,4 +210,32 @@ public class PacketReaderTest { assertEquals(DEFAULT_RECV_BUF_SIZE, b.recvBufSize()); } } + + @Test + public void testStartingFromWrongThread() throws Exception { + final Handler h = mHandlerThread.getThreadHandler(); + final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE); + assertThrows(IllegalStateException.class, () -> b.start()); + } + + @Test + public void testStoppingFromWrongThread() throws Exception { + final Handler h = mHandlerThread.getThreadHandler(); + final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE); + assertThrows(IllegalStateException.class, () -> b.stop()); + } + + @Test + public void testSuccessToCreateSocket() throws Exception { + final Handler h = mHandlerThread.getThreadHandler(); + final PacketReader b = new UdpLoopbackReader(h); + h.post(() -> assertTrue(b.start())); + } + + @Test + public void testFailToCreateSocket() throws Exception { + final Handler h = mHandlerThread.getThreadHandler(); + final PacketReader b = new NullPacketReader(h, DEFAULT_RECV_BUF_SIZE); + h.post(() -> assertFalse(b.start())); + } } diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java index e8bc8d2..49a2ce3 100644 --- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java +++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java @@ -173,6 +173,8 @@ public class NetworkMonitorTest { private static final int VALIDATION_RESULT_VALID = NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS | NETWORK_VALIDATION_RESULT_VALID; + private static final int VALIDATION_RESULT_PRIVDNS_VALID = NETWORK_VALIDATION_PROBE_DNS + | NETWORK_VALIDATION_PROBE_HTTPS | NETWORK_VALIDATION_PROBE_PRIVDNS; private static final int RETURN_CODE_DNS_SUCCESS = 0; private static final int RETURN_CODE_DNS_TIMEOUT = 255; @@ -800,6 +802,8 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); // Verify dns query only get v4 address. resetCallbacks(); @@ -809,6 +813,10 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + // NetworkMonitor will check if the probes has changed or not, if the probes has not + // changed, the callback won't be fired. + verify(mCallbacks, never()).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); // Verify dns query get both v4 and v6 address. resetCallbacks(); @@ -818,10 +826,12 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + verify(mCallbacks, never()).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); } @Test - public void testPrivateDnsResolutionRetryUpdate() throws Exception { + public void testProbeStatusChanged() throws Exception { // Set no record in FakeDns and expect validation to fail. setStatus(mHttpsConnection, 204); setStatus(mHttpConnection, 204); @@ -829,10 +839,40 @@ public class NetworkMonitorTest { WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor(); wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES); - verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce()) - .notifyNetworkTested( - eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyNetworkTested( + eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS + | NETWORK_VALIDATION_PROBE_HTTPS)); + // Fix DNS and retry, expect validation to succeed. + resetCallbacks(); + mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"}, TYPE_AAAA); + + wnm.forceReevaluation(Process.myUid()); + // ProbeCompleted should be reset to 0 + HandlerUtilsKt.waitForIdle(wnm.getHandler(), HANDLER_TIMEOUT_MS); + assertEquals(wnm.getEvaluationState().getProbeCompletedResult(), 0); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)) + .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); + } + + @Test + public void testPrivateDnsResolutionRetryUpdate() throws Exception { + // Set no record in FakeDns and expect validation to fail. + setStatus(mHttpsConnection, 204); + setStatus(mHttpConnection, 204); + + WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor(); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); + wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS)).notifyNetworkTested( + eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS + | NETWORK_VALIDATION_PROBE_HTTPS)); // Fix DNS and retry, expect validation to succeed. resetCallbacks(); @@ -842,6 +882,8 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce()) .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); // Change configuration to an invalid DNS name, expect validation to fail. resetCallbacks(); @@ -853,6 +895,9 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS)) .notifyNetworkTested(eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS + | NETWORK_VALIDATION_PROBE_HTTPS)); // Change configuration back to working again, but make private DNS not work. // Expect validation to fail. @@ -860,10 +905,13 @@ public class NetworkMonitorTest { mFakeDns.setNonBypassPrivateDnsWorking(false); wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); - verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce()) - .notifyNetworkTested( - eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), - eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS)).notifyNetworkTested( + eq(NETWORK_VALIDATION_PROBE_DNS | NETWORK_VALIDATION_PROBE_HTTPS), eq(null)); + // NetworkMonitor will check if the probes has changed or not, if the probes has not + // changed, the callback won't be fired. + verify(mCallbacks, never()).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(NETWORK_VALIDATION_PROBE_DNS + | NETWORK_VALIDATION_PROBE_HTTPS)); // Make private DNS work again. Expect validation to succeed. resetCallbacks(); @@ -872,6 +920,8 @@ public class NetworkMonitorTest { verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce()) .notifyNetworkTested( eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS), eq(null)); + verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1)).notifyProbeStatusChanged( + eq(VALIDATION_RESULT_PRIVDNS_VALID), eq(VALIDATION_RESULT_PRIVDNS_VALID)); } @Test |