summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/com/android/networkstack/netlink/TcpSocketTracker.java53
-rw-r--r--src/com/android/server/connectivity/NetworkMonitor.java7
-rw-r--r--tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java69
-rw-r--r--tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java3
4 files changed, 119 insertions, 13 deletions
diff --git a/src/com/android/networkstack/netlink/TcpSocketTracker.java b/src/com/android/networkstack/netlink/TcpSocketTracker.java
index 8f24aec..8929cdb 100644
--- a/src/com/android/networkstack/netlink/TcpSocketTracker.java
+++ b/src/com/android/networkstack/netlink/TcpSocketTracker.java
@@ -39,6 +39,9 @@ import static android.system.OsConstants.SOL_SOCKET;
import static android.system.OsConstants.SO_SNDTIMEO;
import android.content.Context;
+import android.net.INetd;
+import android.net.MarkMaskParcel;
+import android.net.Network;
import android.net.netlink.NetlinkConstants;
import android.net.netlink.NetlinkSocket;
import android.net.netlink.StructInetDiagMsg;
@@ -46,6 +49,8 @@ import android.net.netlink.StructNlMsgHdr;
import android.net.util.NetworkStackUtils;
import android.net.util.SocketUtils;
import android.os.AsyncTask;
+import android.os.IBinder;
+import android.os.RemoteException;
import android.os.SystemClock;
import android.provider.DeviceConfig;
import android.system.ErrnoException;
@@ -59,6 +64,8 @@ import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.networkstack.apishim.NetworkShimImpl;
+import com.android.networkstack.apishim.UnsupportedApiLevelException;
import java.io.FileDescriptor;
import java.io.InterruptedIOException;
@@ -84,6 +91,8 @@ public class TcpSocketTracker {
private static final long IO_TIMEOUT = 3_000L;
/** Cookie offset of an InetMagMessage header. */
private static final int IDIAG_COOKIE_OFFSET = 44;
+ private static final int UNKNOWN_MARK = 0xffffffff;
+ private static final int NULL_MASK = 0;
/**
* Gather the socket info.
*
@@ -106,6 +115,12 @@ public class TcpSocketTracker {
*/
private final SparseArray<byte[]> mSockDiagMsg = new SparseArray<>();
private final Dependencies mDependencies;
+ private final INetd mNetd;
+ private final Network mNetwork;
+ // The fwmark value of {@code mNetwork}.
+ private final int mNetworkMark;
+ // The network id mask of fwmark.
+ private final int mNetworkMask;
private int mMinPacketsThreshold = DEFAULT_DATA_STALL_MIN_PACKETS_THRESHOLD;
private int mTcpPacketsFailRateThreshold = DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE;
@VisibleForTesting
@@ -124,8 +139,17 @@ public class TcpSocketTracker {
}
};
- public TcpSocketTracker(@NonNull final Dependencies dps) {
+ public TcpSocketTracker(@NonNull final Dependencies dps, @NonNull final Network network) {
mDependencies = dps;
+ mNetwork = network;
+ mNetd = mDependencies.getNetd();
+
+ // If the parcel is null, nothing should be matched which is achieved by the combination of
+ // {@code NULL_MASK} and {@code UNKNOWN_MARK}.
+ final MarkMaskParcel parcel = getNetworkMarkMask();
+ mNetworkMark = (parcel != null) ? parcel.mark : UNKNOWN_MARK;
+ mNetworkMask = (parcel != null) ? parcel.mask : NULL_MASK;
+
// Request tcp info from NetworkStack directly needs extra SELinux permission added after Q
// release.
if (!mDependencies.isTcpInfoParsingSupported()) return;
@@ -145,12 +169,24 @@ public class TcpSocketTracker {
mDependencies.addDeviceConfigChangedListener(mConfigListener);
}
+ @Nullable
+ private MarkMaskParcel getNetworkMarkMask() {
+ try {
+ final int netId = NetworkShimImpl.newInstance(mNetwork).getNetId();
+ return mNetd.getFwmarkForNetwork(netId);
+ } catch (UnsupportedApiLevelException e) {
+ log("Get netId is not available in this API level.");
+ } catch (RemoteException e) {
+ Log.e(TAG, "Error getting fwmark for network, ", e);
+ }
+ return null;
+ }
+
/**
* Request to send a SockDiag Netlink request. Receive and parse the returned message. This
* function is not thread-safe and should only be called from only one thread.
*
* @Return if this polling request executes successfully or not.
- * TODO: Need to filter socket info based on the target network.
*/
public boolean pollSocketsInfo() {
if (!mDependencies.isTcpInfoParsingSupported()) return false;
@@ -294,6 +330,10 @@ public class TcpSocketTracker {
private TcpStat calculateLatestPacketsStat(@NonNull final SocketInfo current,
@Nullable final SocketInfo previous) {
final TcpStat stat = new TcpStat();
+ // Ignore non-target network sockets.
+ if ((current.fwmark & mNetworkMask) != mNetworkMark) {
+ return null;
+ }
if (current.tcpInfo == null) {
log("Current tcpInfo is null.");
@@ -428,7 +468,6 @@ public class TcpSocketTracker {
// One of {@code AF_INET6, AF_INET}.
public final int ipFamily;
// "fwmark" value of the socket queried from native.
- // TODO: Used to do bit-wise '&' operation to get netId information.
public final int fwmark;
// Socket information updated elapsed real time.
public final long updateTime;
@@ -554,6 +593,14 @@ public class TcpSocketTracker {
return mContext;
}
+ /**
+ * Get an INetd connector.
+ */
+ public INetd getNetd() {
+ return INetd.Stub.asInterface(
+ (IBinder) mContext.getSystemService(Context.NETD_SERVICE));
+ }
+
/** Add device config change listener */
public void addDeviceConfigChangedListener(
@NonNull final DeviceConfig.OnPropertiesChangedListener listener) {
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java
index 9b6915c..f950b47 100644
--- a/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -405,7 +405,7 @@ public class NetworkMonitor extends StateMachine {
SharedLog validationLog) {
this(context, cb, network, new IpConnectivityLog(), validationLog,
Dependencies.DEFAULT, new DataStallStatsUtils(),
- getTcpSocketTrackerOrNull(context));
+ getTcpSocketTrackerOrNull(context, network));
}
@VisibleForTesting
@@ -2240,7 +2240,6 @@ public class NetworkMonitor extends StateMachine {
// Check TCP signal. Suspect it may be a data stall if :
// 1. TCP connection fail rate(lost+retrans) is higher than threshold.
// 2. Accumulate enough packets count.
- // TODO: Need to filter per target network.
final TcpSocketTracker tst = getTcpSocketTracker();
if (dataStallEvaluateTypeEnabled(DATA_STALL_EVALUATION_TYPE_TCP) && tst != null) {
if (tst.getLatestReceivedCount() > 0) {
@@ -2406,14 +2405,14 @@ public class NetworkMonitor extends StateMachine {
}
@Nullable
- private static TcpSocketTracker getTcpSocketTrackerOrNull(Context context) {
+ private static TcpSocketTracker getTcpSocketTrackerOrNull(Context context, Network network) {
return ((Dependencies.DEFAULT.getDeviceConfigPropertyInt(
NAMESPACE_CONNECTIVITY,
CONFIG_DATA_STALL_EVALUATION_TYPE,
DEFAULT_DATA_STALL_EVALUATION_TYPES)
& DATA_STALL_EVALUATION_TYPE_TCP) != 0)
? new TcpSocketTracker(new TcpSocketTracker.Dependencies(context,
- ShimUtils.isReleaseOrDevelopmentApiAbove(Build.VERSION_CODES.Q)))
+ ShimUtils.isReleaseOrDevelopmentApiAbove(Build.VERSION_CODES.Q)), network)
: null;
}
}
diff --git a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
index 25ffc66..341545f 100644
--- a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
+++ b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
@@ -22,6 +22,9 @@ import static android.net.util.DataStallUtils.DEFAULT_TCP_PACKETS_FAIL_PERCENTAG
import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
import static android.system.OsConstants.AF_INET;
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
+
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
@@ -31,18 +34,27 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.when;
+import android.net.INetd;
+import android.net.MarkMaskParcel;
+import android.net.Network;
import android.net.netlink.StructNlMsgHdr;
import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;
+import com.android.networkstack.apishim.NetworkShim;
+import com.android.networkstack.apishim.NetworkShimImpl;
+
import libcore.util.HexEncoding;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
+import org.mockito.MockitoSession;
+import org.mockito.quality.Strictness;
import java.io.FileDescriptor;
import java.nio.ByteBuffer;
@@ -111,7 +123,7 @@ public class TcpSocketTrackerTest {
"00000000" + // data
"0800" + // len = 8
"0F00" + // type = 15(INET_DIAG_MARK)
- "851A0C00" + // data, socket mark=793221
+ "850A0C00" + // data, socket mark=789125
"AC00" + // len = 172
"0200" + // type = 2(INET_DIAG_INFO)
// tcp_info
@@ -175,18 +187,48 @@ public class TcpSocketTrackerTest {
+ "00"; // retrans
private static final byte[] TEST_RESPONSE_BYTES =
HexEncoding.decode(TEST_RESPONSE_HEX.toCharArray(), false);
+ private static final int TEST_NETID1 = 0xA85;
+ private static final int TEST_NETID2 = 0x1A85;
+ private static final int TEST_NETID1_FWMARK = 0x0A85;
+ private static final int TEST_NETID2_FWMARK = 0x1A85;
+ private static final int NETID_MASK = 0xffff;
@Mock private TcpSocketTracker.Dependencies mDependencies;
@Mock private FileDescriptor mMockFd;
-
+ @Mock private Network mNetwork = new Network(TEST_NETID1);
+ @Mock private INetd mNetd;
+ private MockitoSession mSession;
+ @Mock NetworkShim mNetworkShim;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
+ when(mDependencies.getNetd()).thenReturn(mNetd);
when(mDependencies.isTcpInfoParsingSupported()).thenReturn(true);
when(mDependencies.connectToKernel()).thenReturn(mMockFd);
when(mDependencies.getDeviceConfigPropertyInt(
eq(NAMESPACE_CONNECTIVITY),
eq(CONFIG_TCP_PACKETS_FAIL_PERCENTAGE),
anyInt())).thenReturn(DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE);
+ mSession = mockitoSession()
+ .spyStatic(NetworkShimImpl.class)
+ .strictness(Strictness.WARN)
+ .startMocking();
+
+ doReturn(mNetworkShim).when(() -> NetworkShimImpl.newInstance(mNetwork));
+ when(mNetworkShim.getNetId()).thenReturn(TEST_NETID1);
+ when(mNetd.getFwmarkForNetwork(eq(TEST_NETID1)))
+ .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID1_FWMARK));
+ }
+
+ @After
+ public void tearDown() {
+ mSession.finishMocking();
+ }
+
+ private MarkMaskParcel makeMarkMaskParcel(final int mask, final int mark) {
+ final MarkMaskParcel parcel = new MarkMaskParcel();
+ parcel.mask = mask;
+ parcel.mark = mark;
+ return parcel;
}
private ByteBuffer getByteBuffer(final byte[] bytes) {
@@ -198,7 +240,7 @@ public class TcpSocketTrackerTest {
@Test
public void testParseSockInfo() {
final ByteBuffer buffer = getByteBuffer(SOCK_DIAG_TCP_INET_BYTES);
- final TcpSocketTracker tst = new TcpSocketTracker(mDependencies);
+ final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
buffer.position(SOCKDIAG_MSG_HEADER_SIZE);
final TcpSocketTracker.SocketInfo parsed =
tst.parseSockInfo(buffer, AF_INET, 276, 100L);
@@ -248,7 +290,7 @@ public class TcpSocketTrackerTest {
expected.put(TcpInfo.Field.DELIVERY_RATE, 0L);
assertEquals(parsed.tcpInfo, new TcpInfo(expected));
- assertEquals(parsed.fwmark, 793221);
+ assertEquals(parsed.fwmark, 789125);
assertEquals(parsed.updateTime, 100);
assertEquals(parsed.ipFamily, AF_INET);
}
@@ -270,7 +312,7 @@ public class TcpSocketTrackerTest {
@Test
public void testPollSocketsInfo() throws Exception {
when(mDependencies.isTcpInfoParsingSupported()).thenReturn(false);
- final TcpSocketTracker tst = new TcpSocketTracker(mDependencies);
+ final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
assertFalse(tst.pollSocketsInfo());
when(mDependencies.isTcpInfoParsingSupported()).thenReturn(true);
@@ -335,7 +377,7 @@ public class TcpSocketTrackerTest {
@Test
public void testPollSocketsInfo_BadFormat() throws Exception {
- final TcpSocketTracker tst = new TcpSocketTracker(mDependencies);
+ final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES);
when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
@@ -351,4 +393,19 @@ public class TcpSocketTrackerTest {
// Expect to reset to 0.
assertEquals(0, tst.getLatestPacketFailPercentage());
}
+
+ @Test
+ public void testUnMatchNetwork() throws Exception {
+ when(mNetworkShim.getNetId()).thenReturn(TEST_NETID2);
+ when(mNetd.getFwmarkForNetwork(eq(TEST_NETID2)))
+ .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID2_FWMARK));
+ final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
+ final ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES);
+ when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer);
+ assertTrue(tst.pollSocketsInfo());
+
+ assertEquals(0, tst.getSentSinceLastRecv());
+ assertEquals(-1, tst.getLatestPacketFailPercentage());
+ assertFalse(tst.isDataStallSuspected());
+ }
}
diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
index 1e512f4..7289ab4 100644
--- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -71,6 +71,7 @@ import android.content.res.Configuration;
import android.content.res.Resources;
import android.net.ConnectivityManager;
import android.net.DnsResolver;
+import android.net.INetd;
import android.net.INetworkMonitorCallbacks;
import android.net.LinkProperties;
import android.net.Network;
@@ -163,6 +164,7 @@ public class NetworkMonitorTest {
private @Mock WifiInfo mWifiInfo;
private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor;
private @Mock TcpSocketTracker.Dependencies mTstDependencies;
+ private @Mock INetd mNetd;
private @Mock TcpSocketTracker mTst;
private HashSet<WrappedNetworkMonitor> mCreatedNetworkMonitors;
private HashSet<BroadcastReceiver> mRegisteredReceivers;
@@ -371,6 +373,7 @@ public class NetworkMonitorTest {
setFallbackSpecs(null); // Test with no fallback spec by default
when(mRandom.nextInt()).thenReturn(0);
+ when(mTstDependencies.getNetd()).thenReturn(mNetd);
// DNS probe timeout should not be defined more than half of HANDLER_TIMEOUT_MS. Otherwise,
// it will fail the test because of timeout expired for querying AAAA and A sequentially.
when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout)))