/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.internal.ml.clustering; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import android.annotation.SuppressLint; import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import java.util.Arrays; import java.util.List; import java.util.Random; @SmallTest @RunWith(AndroidJUnit4.class) public class KMeansTest { // Error tolerance (epsilon) private static final double EPS = 0.01; private KMeans mKMeans; @Before public void setUp() { // Setup with a random seed to have predictable results mKMeans = new KMeans(new Random(0), 30, 0); } @Test public void getCheckDataSanityTest() { try { mKMeans.checkDataSetSanity(new float[][] { {0, 1, 2}, {1, 2, 3} }); } catch (IllegalArgumentException e) { Assert.fail("Valid data didn't pass sanity check"); } try { mKMeans.checkDataSetSanity(new float[][] { null, {1, 2, 3} }); Assert.fail("Data has null items and passed"); } catch (IllegalArgumentException e) {} try { mKMeans.checkDataSetSanity(new float[][] { {0, 1, 2, 4}, {1, 2, 3} }); Assert.fail("Data has invalid shape and passed"); } catch (IllegalArgumentException e) {} try { mKMeans.checkDataSetSanity(null); Assert.fail("Null data should throw exception"); } catch (IllegalArgumentException e) {} } @Test public void sqDistanceTest() { float a[] = {4, 10}; float b[] = {5, 2}; float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2)); assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS); } @Test public void nearestMeanTest() { KMeans.Mean meanA = new KMeans.Mean(0, 1); KMeans.Mean meanB = new KMeans.Mean(1, 1); List means = Arrays.asList(meanA, meanB); KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means); assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB); } @SuppressLint("DefaultLocale") @Test public void scoreTest() { List closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f), new KMeans.Mean(0, 0.1f, 0.15f), new KMeans.Mean(0.1f, 0.2f, 0.1f)); List farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0), new KMeans.Mean(0, 0.5f, 0.5f), new KMeans.Mean(1, 0.9f, 0.9f)); double closeScore = KMeans.score(closeMeans); double farScore = KMeans.score(farMeans); assertTrue(String.format("Score of well distributed means should be greater than " + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore); } @Test public void predictTest() { float[] expectedCentroid1 = {1, 1, 1}; float[] expectedCentroid2 = {0, 0, 0}; float[][] X = new float[][] { {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, }; final int numClusters = 2; // Here we assume that we won't get stuck into a local optima. // It's fine because we're seeding a random, we won't ever have // unstable results but in real life we need multiple initialization // and score comparison List means = mKMeans.predict(numClusters, X); assertEquals("Expected number of clusters is invalid", numClusters, means.size()); boolean exists1 = false, exists2 = false; for (KMeans.Mean mean : means) { if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) { exists1 = true; } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) { exists2 = true; } else { throw new AssertionError("Unexpected mean: " + mean); } } assertTrue("Expected means were not predicted, got: " + means, exists1 && exists2); } }