summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--services/people/java/com/android/server/people/prediction/ShareTargetPredictor.java8
-rw-r--r--services/tests/servicestests/src/com/android/server/people/prediction/ShareTargetPredictorTest.java62
2 files changed, 56 insertions, 14 deletions
diff --git a/services/people/java/com/android/server/people/prediction/ShareTargetPredictor.java b/services/people/java/com/android/server/people/prediction/ShareTargetPredictor.java
index 2e60f2afcdea..236ac8407faa 100644
--- a/services/people/java/com/android/server/people/prediction/ShareTargetPredictor.java
+++ b/services/people/java/com/android/server/people/prediction/ShareTargetPredictor.java
@@ -16,6 +16,8 @@
package com.android.server.people.prediction;
+import static java.util.Collections.reverseOrder;
+
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
@@ -39,6 +41,7 @@ import com.android.server.people.data.PackageData;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.Comparator;
import java.util.List;
import java.util.function.Consumer;
@@ -85,7 +88,9 @@ class ShareTargetPredictor extends AppTargetPredictor {
List<ShareTarget> shareTargets = getDirectShareTargets();
SharesheetModelScorer.computeScore(shareTargets, getShareEventType(mIntentFilter),
System.currentTimeMillis());
- Collections.sort(shareTargets, (t1, t2) -> -Float.compare(t1.getScore(), t2.getScore()));
+ Collections.sort(shareTargets,
+ Comparator.comparing(ShareTarget::getScore, reverseOrder())
+ .thenComparing(t -> t.getAppTarget().getRank()));
List<AppTarget> res = new ArrayList<>();
for (int i = 0; i < Math.min(getPredictionContext().getPredictedTargetCount(),
shareTargets.size()); i++) {
@@ -135,6 +140,7 @@ class ShareTargetPredictor extends AppTargetPredictor {
new AppTargetId(shortcutInfo.getId()),
shortcutInfo)
.setClassName(shareShortcut.getTargetComponent().getClassName())
+ .setRank(shortcutInfo.getRank())
.build();
String packageName = shortcutInfo.getPackage();
int userId = shortcutInfo.getUserId();
diff --git a/services/tests/servicestests/src/com/android/server/people/prediction/ShareTargetPredictorTest.java b/services/tests/servicestests/src/com/android/server/people/prediction/ShareTargetPredictorTest.java
index 60104d390eb7..b09a3c374e86 100644
--- a/services/tests/servicestests/src/com/android/server/people/prediction/ShareTargetPredictorTest.java
+++ b/services/tests/servicestests/src/com/android/server/people/prediction/ShareTargetPredictorTest.java
@@ -117,10 +117,10 @@ public final class ShareTargetPredictorTest {
@Test
public void testPredictTargets() {
- mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc1"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc2"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc3"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc4"));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc1", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc2", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc3", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc4", 0));
when(mPackageData1.getConversationInfo("sc1")).thenReturn(mock(ConversationInfo.class));
when(mPackageData1.getConversationInfo("sc2")).thenReturn(mock(ConversationInfo.class));
@@ -165,12 +165,12 @@ public final class ShareTargetPredictorTest {
@Test
public void testPredictTargets_reachTargetsLimit() {
- mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc1"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc2"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc3"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc4"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc5"));
- mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc6"));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc1", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc2", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc3", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc4", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc5", 0));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc6", 0));
when(mPackageData1.getConversationInfo("sc1")).thenReturn(mock(ConversationInfo.class));
when(mPackageData1.getConversationInfo("sc2")).thenReturn(mock(ConversationInfo.class));
@@ -250,6 +250,41 @@ public final class ShareTargetPredictorTest {
}
@Test
+ public void testPredictTargets_noSharingHistoryRankedByShortcutRank() {
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc1", 3));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_1, CLASS_1, "sc2", 2));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc3", 1));
+ mShareShortcuts.add(buildShareShortcut(PACKAGE_2, CLASS_2, "sc4", 0));
+
+ when(mPackageData1.getConversationInfo("sc1")).thenReturn(mock(ConversationInfo.class));
+ when(mPackageData1.getConversationInfo("sc2")).thenReturn(mock(ConversationInfo.class));
+ when(mPackageData2.getConversationInfo("sc3")).thenReturn(mock(ConversationInfo.class));
+ // "sc4" does not have a ConversationInfo.
+
+ mPredictor.predictTargets();
+
+ verify(mUpdatePredictionsMethod).accept(mAppTargetCaptor.capture());
+ List<AppTarget> res = mAppTargetCaptor.getValue();
+ assertEquals(4, res.size());
+
+ assertEquals("sc4", res.get(0).getId().getId());
+ assertEquals(CLASS_2, res.get(0).getClassName());
+ assertEquals(PACKAGE_2, res.get(0).getPackageName());
+
+ assertEquals("sc3", res.get(1).getId().getId());
+ assertEquals(CLASS_2, res.get(1).getClassName());
+ assertEquals(PACKAGE_2, res.get(1).getPackageName());
+
+ assertEquals("sc2", res.get(2).getId().getId());
+ assertEquals(CLASS_1, res.get(2).getClassName());
+ assertEquals(PACKAGE_1, res.get(2).getPackageName());
+
+ assertEquals("sc1", res.get(3).getId().getId());
+ assertEquals(CLASS_1, res.get(3).getClassName());
+ assertEquals(PACKAGE_1, res.get(3).getPackageName());
+ }
+
+ @Test
public void testSortTargets() {
AppTarget appTarget1 = new AppTarget.Builder(
new AppTargetId("cls1#pkg1"), PACKAGE_1, UserHandle.of(USER_ID))
@@ -348,19 +383,20 @@ public final class ShareTargetPredictorTest {
}
private static ShareShortcutInfo buildShareShortcut(
- String packageName, String className, String shortcutId) {
- ShortcutInfo shortcutInfo = buildShortcut(packageName, shortcutId);
+ String packageName, String className, String shortcutId, int rank) {
+ ShortcutInfo shortcutInfo = buildShortcut(packageName, shortcutId, rank);
ComponentName componentName = new ComponentName(packageName, className);
return new ShareShortcutInfo(shortcutInfo, componentName);
}
- private static ShortcutInfo buildShortcut(String packageName, String shortcutId) {
+ private static ShortcutInfo buildShortcut(String packageName, String shortcutId, int rank) {
Context mockContext = mock(Context.class);
when(mockContext.getPackageName()).thenReturn(packageName);
when(mockContext.getUserId()).thenReturn(USER_ID);
when(mockContext.getUser()).thenReturn(UserHandle.of(USER_ID));
ShortcutInfo.Builder builder = new ShortcutInfo.Builder(mockContext, shortcutId)
.setShortLabel(shortcutId)
+ .setRank(rank)
.setIntent(new Intent("TestIntent"));
return builder.build();
}