diff options
3 files changed, 57 insertions, 15 deletions
diff --git a/src/com/android/networkstack/netlink/TcpSocketTracker.java b/src/com/android/networkstack/netlink/TcpSocketTracker.java index 51b23fd..ef33f13 100644 --- a/src/com/android/networkstack/netlink/TcpSocketTracker.java +++ b/src/com/android/networkstack/netlink/TcpSocketTracker.java @@ -49,6 +49,7 @@ import android.net.netlink.StructNlMsgHdr; import android.net.util.NetworkStackUtils; import android.net.util.SocketUtils; import android.os.AsyncTask; +import android.os.Build; import android.os.IBinder; import android.os.RemoteException; import android.os.SystemClock; @@ -65,6 +66,7 @@ import androidx.annotation.Nullable; import com.android.internal.annotations.VisibleForTesting; import com.android.networkstack.apishim.NetworkShimImpl; +import com.android.networkstack.apishim.common.ShimUtils; import com.android.networkstack.apishim.common.UnsupportedApiLevelException; import java.io.FileDescriptor; @@ -523,11 +525,9 @@ public class TcpSocketTracker { @VisibleForTesting public static class Dependencies { private final Context mContext; - private final boolean mIsTcpInfoParsingSupported; - public Dependencies(final Context context, final boolean tcpSupport) { + public Dependencies(final Context context) { mContext = context; - mIsTcpInfoParsingSupported = tcpSupport; } /** @@ -578,7 +578,7 @@ public class TcpSocketTracker { public boolean isTcpInfoParsingSupported() { // Request tcp info from NetworkStack directly needs extra SELinux permission added // after Q release. - return mIsTcpInfoParsingSupported; + return ShimUtils.isReleaseOrDevelopmentApiAbove(Build.VERSION_CODES.Q); } /** diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java index 61dece0..7fb6761 100755 --- a/src/com/android/server/connectivity/NetworkMonitor.java +++ b/src/com/android/server/connectivity/NetworkMonitor.java @@ -3374,8 +3374,7 @@ public class NetworkMonitor extends StateMachine { CONFIG_DATA_STALL_EVALUATION_TYPE, DEFAULT_DATA_STALL_EVALUATION_TYPES) & DATA_STALL_EVALUATION_TYPE_TCP) != 0) - ? new TcpSocketTracker(new TcpSocketTracker.Dependencies(context, - ShimUtils.isReleaseOrDevelopmentApiAbove(Build.VERSION_CODES.Q)), network) + ? new TcpSocketTracker(new TcpSocketTracker.Dependencies(context), network) : null; } diff --git a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java index 845ad1b..a884f7e 100644 --- a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java +++ b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java @@ -22,6 +22,8 @@ import static android.net.util.DataStallUtils.DEFAULT_TCP_PACKETS_FAIL_PERCENTAG import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY; import static android.system.OsConstants.AF_INET; +import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation; + import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn; import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; @@ -29,28 +31,37 @@ import static junit.framework.Assert.assertEquals; import static junit.framework.Assert.assertFalse; import static junit.framework.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import android.net.INetd; import android.net.MarkMaskParcel; import android.net.Network; import android.net.netlink.StructNlMsgHdr; +import android.os.Build; import android.util.Log; import android.util.Log.TerribleFailureHandler; import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; +import com.android.networkstack.apishim.ConstantsShim; import com.android.networkstack.apishim.NetworkShimImpl; -import com.android.networkstack.apishim.common.NetworkShim; +import com.android.testutils.DevSdkIgnoreRule; +import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter; +import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo; import libcore.util.HexEncoding; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -197,12 +208,15 @@ public class TcpSocketTrackerTest { private static final int NETID_MASK = 0xffff; @Mock private TcpSocketTracker.Dependencies mDependencies; @Mock private FileDescriptor mMockFd; - @Mock private Network mNetwork = new Network(TEST_NETID1); @Mock private INetd mNetd; + private final Network mNetwork = new Network(TEST_NETID1); + private final Network mOtherNetwork = new Network(TEST_NETID2); private MockitoSession mSession; - @Mock NetworkShim mNetworkShim; private TerribleFailureHandler mOldWtfHandler; + @Rule + public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule(); + @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); @@ -222,8 +236,6 @@ public class TcpSocketTrackerTest { .strictness(Strictness.WARN) .startMocking(); - doReturn(mNetworkShim).when(() -> NetworkShimImpl.newInstance(mNetwork)); - when(mNetworkShim.getNetId()).thenReturn(TEST_NETID1); when(mNetd.getFwmarkForNetwork(eq(TEST_NETID1))) .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID1_FWMARK)); } @@ -275,8 +287,10 @@ public class TcpSocketTrackerTest { assertFalse(TcpSocketTracker.enoughBytesRemainForValidNlMsg(buffer)); } - @Test + @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q public void testPollSocketsInfo() throws Exception { + // This test requires shims that provide API 30 access + assumeTrue(ConstantsShim.VERSION >= 30); when(mDependencies.isTcpInfoParsingSupported()).thenReturn(false); final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); assertFalse(tst.pollSocketsInfo()); @@ -314,6 +328,34 @@ public class TcpSocketTrackerTest { assertTrue(tst.isDataStallSuspected()); } + @Test + public void testTcpInfoParsingUnsupported() { + doReturn(false).when(mDependencies).isTcpInfoParsingSupported(); + final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); + verify(mDependencies).getNetd(); + + assertFalse(tst.pollSocketsInfo()); + assertEquals(-1, tst.getLatestPacketFailPercentage()); + assertEquals(-1, tst.getLatestReceivedCount()); + assertEquals(-1, tst.getSentSinceLastRecv()); + assertFalse(tst.isDataStallSuspected()); + + verify(mDependencies, atLeastOnce()).isTcpInfoParsingSupported(); + verifyNoMoreInteractions(mDependencies); + } + + @Test @IgnoreAfter(Build.VERSION_CODES.Q) + public void testTcpInfoParsingNotSupportedOnQ() { + assertFalse(new TcpSocketTracker.Dependencies(getInstrumentation().getContext()) + .isTcpInfoParsingSupported()); + } + + @Test @IgnoreUpTo(Build.VERSION_CODES.Q) + public void testTcpInfoParsingSupportedFromR() { + assertTrue(new TcpSocketTracker.Dependencies(getInstrumentation().getContext()) + .isTcpInfoParsingSupported()); + } + private static final String BAD_DIAG_MSG_HEX = // struct nlmsghdr. "00000058" + // length = 1476395008 @@ -341,8 +383,10 @@ public class TcpSocketTrackerTest { private static final byte[] BAD_SOCK_DIAG_MSG_BYTES = HexEncoding.decode(BAD_DIAG_MSG_HEX.toCharArray(), false); - @Test + @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q public void testPollSocketsInfo_BadFormat() throws Exception { + // This test requires shims that provide API 30 access + assumeTrue(ConstantsShim.VERSION >= 30); final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES); @@ -362,10 +406,9 @@ public class TcpSocketTrackerTest { @Test public void testUnMatchNetwork() throws Exception { - when(mNetworkShim.getNetId()).thenReturn(TEST_NETID2); when(mNetd.getFwmarkForNetwork(eq(TEST_NETID2))) .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID2_FWMARK)); - final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); + final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mOtherNetwork); final ByteBuffer tcpBuffer = getByteBuffer(TEST_RESPONSE_BYTES); when(mDependencies.recvMessage(any())).thenReturn(tcpBuffer); assertTrue(tst.pollSocketsInfo()); |