diff options
10 files changed, 689 insertions, 14 deletions
diff --git a/core/java/android/app/WallpaperColors.java b/core/java/android/app/WallpaperColors.java index 23e9ca5c32ae..8f172ba806f6 100644 --- a/core/java/android/app/WallpaperColors.java +++ b/core/java/android/app/WallpaperColors.java @@ -27,6 +27,7 @@ import android.os.Parcelable; import android.util.Size; import com.android.internal.graphics.palette.Palette; +import com.android.internal.graphics.palette.VariationalKMeansQuantizer; import java.util.ArrayList; import java.util.Collections; @@ -142,6 +143,8 @@ public final class WallpaperColors implements Parcelable { final Palette palette = Palette .from(bitmap) + .setQuantizer(new VariationalKMeansQuantizer()) + .maximumColorCount(5) .clearFilters() .resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA) .generate(); diff --git a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java index 56d60a13e159..9ac753b6d6ce 100644 --- a/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java +++ b/core/java/com/android/internal/graphics/palette/ColorCutQuantizer.java @@ -61,7 +61,7 @@ import com.android.internal.graphics.palette.Palette.Swatch; * This means that the color space is divided into distinct colors, rather than representative * colors. */ -final class ColorCutQuantizer { +final class ColorCutQuantizer implements Quantizer { private static final String LOG_TAG = "ColorCutQuantizer"; private static final boolean LOG_TIMINGS = false; @@ -73,22 +73,22 @@ final class ColorCutQuantizer { private static final int QUANTIZE_WORD_WIDTH = 5; private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1; - final int[] mColors; - final int[] mHistogram; - final List<Swatch> mQuantizedColors; - final TimingLogger mTimingLogger; - final Palette.Filter[] mFilters; + int[] mColors; + int[] mHistogram; + List<Swatch> mQuantizedColors; + TimingLogger mTimingLogger; + Palette.Filter[] mFilters; private final float[] mTempHsl = new float[3]; /** - * Constructor. + * Execute color quantization. * * @param pixels histogram representing an image's pixel data * @param maxColors The maximum number of colors that should be in the result palette. * @param filters Set of filters to use in the quantization stage */ - ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) { + public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) { mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null; mFilters = filters; @@ -160,7 +160,7 @@ final class ColorCutQuantizer { /** * @return the list of quantized colors */ - List<Swatch> getQuantizedColors() { + public List<Swatch> getQuantizedColors() { return mQuantizedColors; } diff --git a/core/java/com/android/internal/graphics/palette/Palette.java b/core/java/com/android/internal/graphics/palette/Palette.java index 9f1504a0495c..a4f9a596050c 100644 --- a/core/java/com/android/internal/graphics/palette/Palette.java +++ b/core/java/com/android/internal/graphics/palette/Palette.java @@ -613,6 +613,8 @@ public final class Palette { private final List<Palette.Filter> mFilters = new ArrayList<>(); private Rect mRegion; + private Quantizer mQuantizer; + /** * Construct a new {@link Palette.Builder} using a source {@link Bitmap} */ @@ -726,6 +728,18 @@ public final class Palette { } /** + * Set a specific quantization algorithm. {@link ColorCutQuantizer} will + * be used if unspecified. + * + * @param quantizer Quantizer implementation. + */ + @NonNull + public Palette.Builder setQuantizer(Quantizer quantizer) { + mQuantizer = quantizer; + return this; + } + + /** * Set a region of the bitmap to be used exclusively when calculating the palette. * <p>This only works when the original input is a {@link Bitmap}.</p> * @@ -818,17 +832,19 @@ public final class Palette { } // Now generate a quantizer from the Bitmap - final ColorCutQuantizer quantizer = new ColorCutQuantizer( - getPixelsFromBitmap(bitmap), - mMaxColors, - mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()])); + if (mQuantizer == null) { + mQuantizer = new ColorCutQuantizer(); + } + mQuantizer.quantize(getPixelsFromBitmap(bitmap), + mMaxColors, mFilters.isEmpty() ? null : + mFilters.toArray(new Palette.Filter[mFilters.size()])); // If created a new bitmap, recycle it if (bitmap != mBitmap) { bitmap.recycle(); } - swatches = quantizer.getQuantizedColors(); + swatches = mQuantizer.getQuantizedColors(); if (logger != null) { logger.addSplit("Color quantization completed"); diff --git a/core/java/com/android/internal/graphics/palette/Quantizer.java b/core/java/com/android/internal/graphics/palette/Quantizer.java new file mode 100644 index 000000000000..db60f2e9dc69 --- /dev/null +++ b/core/java/com/android/internal/graphics/palette/Quantizer.java @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package com.android.internal.graphics.palette; + +import java.util.List; + +/** + * Definition of an algorithm that receives pixels and outputs a list of colors. + */ +public interface Quantizer { + void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters); + List<Palette.Swatch> getQuantizedColors(); +} diff --git a/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java new file mode 100644 index 000000000000..b0355350dc15 --- /dev/null +++ b/core/java/com/android/internal/graphics/palette/VariationalKMeansQuantizer.java @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package com.android.internal.graphics.palette; + +import android.util.Log; + +import com.android.internal.graphics.ColorUtils; +import com.android.internal.ml.clustering.KMeans; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * A quantizer that uses k-means + */ +public class VariationalKMeansQuantizer implements Quantizer { + + private static final String TAG = "KMeansQuantizer"; + private static final boolean DEBUG = false; + + /** + * Clusters closer than this value will me merged. + */ + private final float mMinClusterSqDistance; + + /** + * K-means can get stuck in local optima, this can be avoided by + * repeating it and getting the "best" execution. + */ + private final int mInitializations; + + /** + * Initialize KMeans with a fixed random state to have + * consistent results across multiple runs. + */ + private final KMeans mKMeans = new KMeans(new Random(0), 30, 0); + + private List<Palette.Swatch> mQuantizedColors; + + public VariationalKMeansQuantizer() { + this(0.25f /* cluster distance */); + } + + public VariationalKMeansQuantizer(float minClusterDistance) { + this(minClusterDistance, 1 /* initializations */); + } + + public VariationalKMeansQuantizer(float minClusterDistance, int initializations) { + mMinClusterSqDistance = minClusterDistance * minClusterDistance; + mInitializations = initializations; + } + + /** + * K-Means quantizer. + * + * @param pixels Pixels to quantize. + * @param maxColors Maximum number of clusters to extract. + * @param filters Colors that should be ignored + */ + @Override + public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) { + // Start by converting all colors to HSL. + // HLS is way more meaningful for clustering than RGB. + final float[] hsl = {0, 0, 0}; + final float[][] hslPixels = new float[pixels.length][3]; + for (int i = 0; i < pixels.length; i++) { + ColorUtils.colorToHSL(pixels[i], hsl); + // Normalize hue so all values go from 0 to 1. + hslPixels[i][0] = hsl[0] / 360f; + hslPixels[i][1] = hsl[1]; + hslPixels[i][2] = hsl[2]; + } + + final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels); + + // Ideally we should run k-means again to merge clusters but it would be too expensive, + // instead we just merge all clusters that are closer than a threshold. + for (int i = 0; i < optimalMeans.size(); i++) { + KMeans.Mean current = optimalMeans.get(i); + float[] currentCentroid = current.getCentroid(); + for (int j = i + 1; j < optimalMeans.size(); j++) { + KMeans.Mean compareTo = optimalMeans.get(j); + float[] compareToCentroid = compareTo.getCentroid(); + float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid); + // Merge them + if (sqDistance < mMinClusterSqDistance) { + optimalMeans.remove(compareTo); + current.getItems().addAll(compareTo.getItems()); + for (int k = 0; k < currentCentroid.length; k++) { + currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0; + } + j--; + } + } + } + + // Convert data to final format, de-normalizing the hue. + mQuantizedColors = new ArrayList<>(); + for (KMeans.Mean mean : optimalMeans) { + if (mean.getItems().size() == 0) { + continue; + } + float[] centroid = mean.getCentroid(); + mQuantizedColors.add(new Palette.Swatch(new float[]{ + centroid[0] * 360f, + centroid[1], + centroid[2] + }, mean.getItems().size())); + } + } + + private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) { + List<KMeans.Mean> optimal = null; + double optimalScore = -Double.MAX_VALUE; + int runs = mInitializations; + while (runs > 0) { + if (DEBUG) { + Log.d(TAG, "k-means run: " + runs); + } + List<KMeans.Mean> means = mKMeans.predict(k, inputData); + double score = KMeans.score(means); + if (optimal == null || score > optimalScore) { + if (DEBUG) { + Log.d(TAG, "\tnew optimal score: " + score); + } + optimalScore = score; + optimal = means; + } + runs--; + } + + return optimal; + } + + @Override + public List<Palette.Swatch> getQuantizedColors() { + return mQuantizedColors; + } +} diff --git a/core/java/com/android/internal/ml/clustering/KMeans.java b/core/java/com/android/internal/ml/clustering/KMeans.java new file mode 100644 index 000000000000..4d5b3331e7b1 --- /dev/null +++ b/core/java/com/android/internal/ml/clustering/KMeans.java @@ -0,0 +1,243 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package com.android.internal.ml.clustering; + +import android.annotation.NonNull; +import android.util.Log; + +import com.android.internal.annotations.VisibleForTesting; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * Simple K-Means implementation + */ +public class KMeans { + + private static final boolean DEBUG = false; + private static final String TAG = "KMeans"; + private final Random mRandomState; + private final int mMaxIterations; + private float mSqConvergenceEpsilon; + + public KMeans() { + this(new Random()); + } + + public KMeans(Random random) { + this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */); + } + public KMeans(Random random, int maxIterations, float convergenceEpsilon) { + mRandomState = random; + mMaxIterations = maxIterations; + mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon; + } + + /** + * Runs k-means on the input data (X) trying to find k means. + * + * K-Means is known for getting stuck into local optima, so you might + * want to run it multiple time and argmax on {@link KMeans#score(List)} + * + * @param k The number of points to return. + * @param inputData Input data. + * @return An array of k Means, each representing a centroid and data points that belong to it. + */ + public List<Mean> predict(final int k, final float[][] inputData) { + checkDataSetSanity(inputData); + int dimension = inputData[0].length; + + final ArrayList<Mean> means = new ArrayList<>(); + for (int i = 0; i < k; i++) { + Mean m = new Mean(dimension); + for (int j = 0; j < dimension; j++) { + m.mCentroid[j] = mRandomState.nextFloat(); + } + means.add(m); + } + + // Iterate until we converge or run out of iterations + boolean converged = false; + for (int i = 0; i < mMaxIterations; i++) { + converged = step(means, inputData); + if (converged) { + if (DEBUG) Log.d(TAG, "Converged at iteration: " + i); + break; + } + } + if (!converged && DEBUG) Log.d(TAG, "Did not converge"); + + return means; + } + + /** + * Score calculates the inertia between means. + * This can be considered as an E step of an EM algorithm. + * + * @param means Means to use when calculating score. + * @return The score + */ + public static double score(@NonNull List<Mean> means) { + double score = 0; + final int meansSize = means.size(); + for (int i = 0; i < meansSize; i++) { + Mean mean = means.get(i); + for (int j = 0; j < meansSize; j++) { + Mean compareTo = means.get(j); + if (mean == compareTo) { + continue; + } + double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid)); + score += distance; + } + } + return score; + } + + @VisibleForTesting + public void checkDataSetSanity(float[][] inputData) { + if (inputData == null) { + throw new IllegalArgumentException("Data set is null."); + } else if (inputData.length == 0) { + throw new IllegalArgumentException("Data set is empty."); + } else if (inputData[0] == null) { + throw new IllegalArgumentException("Bad data set format."); + } + + final int dimension = inputData[0].length; + final int length = inputData.length; + for (int i = 1; i < length; i++) { + if (inputData[i] == null || inputData[i].length != dimension) { + throw new IllegalArgumentException("Bad data set format."); + } + } + } + + /** + * K-Means iteration. + * + * @param means Current means + * @param inputData Input data + * @return True if data set converged + */ + private boolean step(final ArrayList<Mean> means, final float[][] inputData) { + + // Clean up the previous state because we need to compute + // which point belongs to each mean again. + for (int i = means.size() - 1; i >= 0; i--) { + final Mean mean = means.get(i); + mean.mClosestItems.clear(); + } + for (int i = inputData.length - 1; i >= 0; i--) { + final float[] current = inputData[i]; + final Mean nearest = nearestMean(current, means); + nearest.mClosestItems.add(current); + } + + boolean converged = true; + // Move each mean towards the nearest data set points + for (int i = means.size() - 1; i >= 0; i--) { + final Mean mean = means.get(i); + if (mean.mClosestItems.size() == 0) { + continue; + } + + // Compute the new mean centroid: + // 1. Sum all all points + // 2. Average them + final float[] oldCentroid = mean.mCentroid; + mean.mCentroid = new float[oldCentroid.length]; + for (int j = 0; j < mean.mClosestItems.size(); j++) { + // Update each centroid component + for (int p = 0; p < mean.mCentroid.length; p++) { + mean.mCentroid[p] += mean.mClosestItems.get(j)[p]; + } + } + for (int j = 0; j < mean.mCentroid.length; j++) { + mean.mCentroid[j] /= mean.mClosestItems.size(); + } + + // We converged if the centroid didn't move for any of the means. + if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) { + converged = false; + } + } + return converged; + } + + @VisibleForTesting + public static Mean nearestMean(float[] point, List<Mean> means) { + Mean nearest = null; + float nearestDistance = Float.MAX_VALUE; + + final int meanCount = means.size(); + for (int i = 0; i < meanCount; i++) { + Mean next = means.get(i); + // We don't need the sqrt when comparing distances in euclidean space + // because they exist on both sides of the equation and cancel each other out. + float nextDistance = sqDistance(point, next.mCentroid); + if (nextDistance < nearestDistance) { + nearest = next; + nearestDistance = nextDistance; + } + } + return nearest; + } + + @VisibleForTesting + public static float sqDistance(float[] a, float[] b) { + float dist = 0; + final int length = a.length; + for (int i = 0; i < length; i++) { + dist += (a[i] - b[i]) * (a[i] - b[i]); + } + return dist; + } + + /** + * Definition of a mean, contains a centroid and points on its cluster. + */ + public static class Mean { + float[] mCentroid; + final ArrayList<float[]> mClosestItems = new ArrayList<>(); + + public Mean(int dimension) { + mCentroid = new float[dimension]; + } + + public Mean(float ...centroid) { + mCentroid = centroid; + } + + public float[] getCentroid() { + return mCentroid; + } + + public List<float[]> getItems() { + return mClosestItems; + } + + @Override + public String toString() { + return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: " + + mClosestItems.size() + ")"; + } + } +} diff --git a/tests/Internal/Android.mk b/tests/Internal/Android.mk new file mode 100644 index 000000000000..f59a6240f897 --- /dev/null +++ b/tests/Internal/Android.mk @@ -0,0 +1,20 @@ +LOCAL_PATH:= $(call my-dir) +include $(CLEAR_VARS) + +LOCAL_USE_AAPT2 := true +LOCAL_MODULE_TAGS := tests + +LOCAL_PROTOC_OPTIMIZE_TYPE := nano + +# Include some source files directly to be able to access package members +LOCAL_SRC_FILES := $(call all-java-files-under, src) + +LOCAL_JAVA_LIBRARIES := android.test.runner +LOCAL_STATIC_JAVA_LIBRARIES := junit legacy-android-test android-support-test + +LOCAL_CERTIFICATE := platform + +LOCAL_PACKAGE_NAME := InternalTests +LOCAL_COMPATIBILITY_SUITE := device-tests + +include $(BUILD_PACKAGE) diff --git a/tests/Internal/AndroidManifest.xml b/tests/Internal/AndroidManifest.xml new file mode 100644 index 000000000000..a2c95fbbfc0b --- /dev/null +++ b/tests/Internal/AndroidManifest.xml @@ -0,0 +1,28 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2017 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License + --> + +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.android.internal.tests"> + <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" /> + <application> + <uses-library android:name="android.test.runner" /> + </application> + + <instrumentation android:name="android.support.test.runner.AndroidJUnitRunner" + android:targetPackage="com.android.internal.tests" + android:label="Internal Tests" /> +</manifest> diff --git a/tests/Internal/AndroidTest.xml b/tests/Internal/AndroidTest.xml new file mode 100644 index 000000000000..6531c9355e3d --- /dev/null +++ b/tests/Internal/AndroidTest.xml @@ -0,0 +1,29 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + ~ Copyright (C) 2017 The Android Open Source Project + ~ + ~ Licensed under the Apache License, Version 2.0 (the "License"); + ~ you may not use this file except in compliance with the License. + ~ You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License + --> +<configuration description="Runs tests for internal classes/utilities."> + <target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup"> + <option name="test-file-name" value="InternalTests.apk" /> + </target_preparer> + + <option name="test-suite-tag" value="apct" /> + <option name="test-suite-tag" value="framework-base-presubmit" /> + <option name="test-tag" value="InternalTests" /> + <test class="com.android.tradefed.testtype.AndroidJUnitTest" > + <option name="package" value="com.android.internal.tests" /> + <option name="runner" value="android.support.test.runner.AndroidJUnitRunner" /> + </test> +</configuration>
\ No newline at end of file diff --git a/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java new file mode 100644 index 000000000000..a64f8a60d485 --- /dev/null +++ b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.internal.ml.clustering; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import android.annotation.SuppressLint; +import android.support.test.filters.SmallTest; +import android.support.test.runner.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +@SmallTest +@RunWith(AndroidJUnit4.class) +public class KMeansTest { + + // Error tolerance (epsilon) + private static final double EPS = 0.01; + + private KMeans mKMeans; + + @Before + public void setUp() { + // Setup with a random seed to have predictable results + mKMeans = new KMeans(new Random(0), 30, 0); + } + + @Test + public void getCheckDataSanityTest() { + try { + mKMeans.checkDataSetSanity(new float[][] { + {0, 1, 2}, + {1, 2, 3} + }); + } catch (IllegalArgumentException e) { + Assert.fail("Valid data didn't pass sanity check"); + } + + try { + mKMeans.checkDataSetSanity(new float[][] { + null, + {1, 2, 3} + }); + Assert.fail("Data has null items and passed"); + } catch (IllegalArgumentException e) {} + + try { + mKMeans.checkDataSetSanity(new float[][] { + {0, 1, 2, 4}, + {1, 2, 3} + }); + Assert.fail("Data has invalid shape and passed"); + } catch (IllegalArgumentException e) {} + + try { + mKMeans.checkDataSetSanity(null); + Assert.fail("Null data should throw exception"); + } catch (IllegalArgumentException e) {} + } + + @Test + public void sqDistanceTest() { + float a[] = {4, 10}; + float b[] = {5, 2}; + float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2)); + + assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS); + } + + @Test + public void nearestMeanTest() { + KMeans.Mean meanA = new KMeans.Mean(0, 1); + KMeans.Mean meanB = new KMeans.Mean(1, 1); + List<KMeans.Mean> means = Arrays.asList(meanA, meanB); + + KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means); + + assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB); + } + + @SuppressLint("DefaultLocale") + @Test + public void scoreTest() { + List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f), + new KMeans.Mean(0, 0.1f, 0.15f), + new KMeans.Mean(0.1f, 0.2f, 0.1f)); + List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0), + new KMeans.Mean(0, 0.5f, 0.5f), + new KMeans.Mean(1, 0.9f, 0.9f)); + + double closeScore = KMeans.score(closeMeans); + double farScore = KMeans.score(farMeans); + assertTrue(String.format("Score of well distributed means should be greater than " + + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore); + } + + @Test + public void predictTest() { + float[] expectedCentroid1 = {1, 1, 1}; + float[] expectedCentroid2 = {0, 0, 0}; + float[][] X = new float[][] { + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + }; + + final int numClusters = 2; + + // Here we assume that we won't get stuck into a local optima. + // It's fine because we're seeding a random, we won't ever have + // unstable results but in real life we need multiple initialization + // and score comparison + List<KMeans.Mean> means = mKMeans.predict(numClusters, X); + + assertEquals("Expected number of clusters is invalid", numClusters, means.size()); + + boolean exists1 = false, exists2 = false; + for (KMeans.Mean mean : means) { + if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) { + exists1 = true; + } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) { + exists2 = true; + } else { + throw new AssertionError("Unexpected mean: " + mean); + } + } + assertTrue("Expected means were not predicted, got: " + means, + exists1 && exists2); + } +} |