diff options
author | Lucas Dupin <dupin@google.com> | 2017-06-05 08:40:39 -0700 |
---|---|---|
committer | Lucas Dupin <dupin@google.com> | 2017-06-16 18:04:42 -0700 |
commit | 1d3c00d5c7a906d5a75ff66e05c7a865ff82fd4e (patch) | |
tree | 330c3e0fcb94ee985f8917467f899382d1676523 /tests/Internal/src | |
parent | 14d9e3a5295a846863c90d4313b26ae2f5e7f17f (diff) |
K-Means color clustering
Test: runtest -x tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java
Bug: 37014702
Change-Id: Idabc163df5ded362acbe462ae6b351394a36db10
Diffstat (limited to 'tests/Internal/src')
-rw-r--r-- | tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java new file mode 100644 index 000000000000..a64f8a60d485 --- /dev/null +++ b/tests/Internal/src/com/android/internal/ml/clustering/KMeansTest.java @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.internal.ml.clustering; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import android.annotation.SuppressLint; +import android.support.test.filters.SmallTest; +import android.support.test.runner.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +@SmallTest +@RunWith(AndroidJUnit4.class) +public class KMeansTest { + + // Error tolerance (epsilon) + private static final double EPS = 0.01; + + private KMeans mKMeans; + + @Before + public void setUp() { + // Setup with a random seed to have predictable results + mKMeans = new KMeans(new Random(0), 30, 0); + } + + @Test + public void getCheckDataSanityTest() { + try { + mKMeans.checkDataSetSanity(new float[][] { + {0, 1, 2}, + {1, 2, 3} + }); + } catch (IllegalArgumentException e) { + Assert.fail("Valid data didn't pass sanity check"); + } + + try { + mKMeans.checkDataSetSanity(new float[][] { + null, + {1, 2, 3} + }); + Assert.fail("Data has null items and passed"); + } catch (IllegalArgumentException e) {} + + try { + mKMeans.checkDataSetSanity(new float[][] { + {0, 1, 2, 4}, + {1, 2, 3} + }); + Assert.fail("Data has invalid shape and passed"); + } catch (IllegalArgumentException e) {} + + try { + mKMeans.checkDataSetSanity(null); + Assert.fail("Null data should throw exception"); + } catch (IllegalArgumentException e) {} + } + + @Test + public void sqDistanceTest() { + float a[] = {4, 10}; + float b[] = {5, 2}; + float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2)); + + assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS); + } + + @Test + public void nearestMeanTest() { + KMeans.Mean meanA = new KMeans.Mean(0, 1); + KMeans.Mean meanB = new KMeans.Mean(1, 1); + List<KMeans.Mean> means = Arrays.asList(meanA, meanB); + + KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means); + + assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB); + } + + @SuppressLint("DefaultLocale") + @Test + public void scoreTest() { + List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f), + new KMeans.Mean(0, 0.1f, 0.15f), + new KMeans.Mean(0.1f, 0.2f, 0.1f)); + List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0), + new KMeans.Mean(0, 0.5f, 0.5f), + new KMeans.Mean(1, 0.9f, 0.9f)); + + double closeScore = KMeans.score(closeMeans); + double farScore = KMeans.score(farMeans); + assertTrue(String.format("Score of well distributed means should be greater than " + + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore); + } + + @Test + public void predictTest() { + float[] expectedCentroid1 = {1, 1, 1}; + float[] expectedCentroid2 = {0, 0, 0}; + float[][] X = new float[][] { + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + }; + + final int numClusters = 2; + + // Here we assume that we won't get stuck into a local optima. + // It's fine because we're seeding a random, we won't ever have + // unstable results but in real life we need multiple initialization + // and score comparison + List<KMeans.Mean> means = mKMeans.predict(numClusters, X); + + assertEquals("Expected number of clusters is invalid", numClusters, means.size()); + + boolean exists1 = false, exists2 = false; + for (KMeans.Mean mean : means) { + if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) { + exists1 = true; + } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) { + exists2 = true; + } else { + throw new AssertionError("Unexpected mean: " + mean); + } + } + assertTrue("Expected means were not predicted, got: " + means, + exists1 && exists2); + } +} |