机器学习入门:K-Means聚类算法

December 10, 2016
Machine LearningJavaBig DataAlgorithm

聚类算法是机器学习和数据挖掘领域中的一种常用算法,用于进行数据分类,把不同的数据分到不同的群组,听起来没什么的,但是用途还是挺多的,公司可以对客户资料进行聚类来对不同的客户采用不同的商业模式,电商可以根据聚类来为你推荐相似的商品。学校可以对学生考试成绩聚类来看你是好学生还是差学生。这篇博客将会讲述一种简单的聚类算法,K-Means聚类算法。

  • 问题引出

先来一个简单的问题吧,初中知识,在一个二维坐标系中有很多的点,如下图

看图,我需要吧这些点分为三组,怎么分?这是一道送分题,人眼可以直观的看出些点可以分为三组:

那么问题来了,在知道有三组的点的情况下,如何让计算机去为这些点分组呢?

我传入点的坐标
(1,1),(1,2),(2,4),(10,10),(11,12),(11,13),(80,90),(82,94),(83,89)这九个点给计算机,怎么样设计一个算法得出
(1,1),(1,2),(2,4) 是一个组,(10,10),(11,12)是一个组,(80,90),(82,94),(83,89)是一个组呢?

这个时候就就要用到聚类算法了,首先对之后的关键词进行一下解释,我刚才提到的组,在算法中叫做 “聚类”(Cluster)

  • 算法流程

这个算法过程很简单,我要把这堆点分成三组的话,首先生成三个随机的点,这三个随机点叫做质心,然后把各个点分配到离他最近的质心。三个质心最后肯定会形成三个组,每个组都有若干个被分配到的点,分组完成后,对每个组的点求平均位置,这个位置就是新的质心的位置,之后重复上面的步骤,根据新的质心再分组,然后再求平均,若干次后质心的位置可能不会再变了,这个现象称为
收敛(Converge) ,然后分组结束。

算法之所以叫做K-Means算法,K代表要分成K个聚类,Means就是平均啦,刚才也提到了。

用一张图来直观表示一下,这张图是一本书上的,网上也到处都有这个图,我就不重绘了

这张图是吧A,B,C,D,E,5个点分成两个聚类,实心点是两个随机生成的点,迭代了4次完成聚类。

  • 定义距离度量标准

上文提到了,将每个点分配到最近的中心点,但是怎么知道每个点离最近的中心点近呢?假设中心点是(5,5)和(30,30),点(4,5)离哪个中心点近呢?当然是(5,5)啦,怎么算呢?初中知识,两点间距离公式。专业点的话,这个距离叫做欧几里得距离,之前的博客提到过可以去看看推荐系统入门之协作型过滤算法

这个距离叫做 紧密度(Closeness)
,在这个计算点的应用中,欧式距离越小代表两个点越近,紧密度越高。在其他应用中还会有其他的距离度量标准,比如刚才那片博客,根据用户对不同电影的评分寻找相似用户的应用中就采用了皮尔逊相关系数来评估两个用户之间的距离,1代表紧密,-1代表没关系。

一定要注意距离度量的标准不同的应用应该采取不同的方法!

在这个给坐标系中的点分组的应用里,使用欧式距离,不能使用皮尔逊相关系数,即使我上篇博客说他很牛叉,皮尔逊相关系数代表的是趋势。

点(2,5)和(20,50),和(400,100)求皮尔逊相关系数的结果都是1,非常紧密,但实际上这三个点差的远着呢。

  • 程序设计

好了,理论够多了,看看代码怎么写吧,这里出现的所有代码在文章末尾都有Github链接可供下载。

首先是 Distance类 ,集成了几个常用的计算距离的算法:

  1. /**
  2. * Created by Mike on 12/9/2016.
  3. */
  4. public class Distance {
  5. /**
  6. * 计算欧式距离
  7. * @param vector1
  8. * @param vector2
  9. * @return
  10. */
  11. public static double getSimDistance(MyVector vector1, MyVector vector2){
  12. double sumOfSquare = 0.0;
  13. for (int i=0;i<vector1.size();i++){
  14. double vector1Score = vector1.get(i);
  15. double vector2Score = vector2.get(i);
  16. sumOfSquare += Math.pow(vector1Score - vector2Score,2);
  17. }
  18. return Math.sqrt(sumOfSquare);
  19. //return (1 / (1 + Math.sqrt(sumOfSquare)));
  20. }
  21. /**计算皮尔逊相关系数
  22. * 返回 -1 ~ 1
  23. * -1-0为正相关
  24. * 0-1为正相关
  25. */
  26. public static double getPearsonDistance(MyVector vector1, MyVector vector2){
  27. double sum1 = 0.0;
  28. double sum2 = 0.0;
  29. double sum1Sq = 0.0;
  30. double sum2Sq = 0.0;
  31. double pSum = 0.0;
  32. for (int i =0;i<vector1.size();i++){
  33. sum1 += vector1.get(i);
  34. sum2 += vector2.get(i);
  35. sum1Sq += Math.pow(vector1.get(i),2);
  36. sum2Sq += Math.pow(vector2.get(i),2);
  37. pSum += vector1.get(i) * vector2.get(i);
  38. }
  39. int n = vector1.size();
  40. double num = pSum - ((sum1*sum2)/n);
  41. double den = Math.sqrt( (sum1Sq - Math.pow(sum1,2)/n) * (sum2Sq - Math.pow(sum2,2)/n) );
  42. if (den ==0){
  43. return 0;
  44. }else {
  45. return num/den;
  46. }
  47. }
  48. }

Cluster类

这个类代表了一个聚类的抽象

维护了一个Vector的列表,这里面我用MyVector类重新封装了Java的Vector类,其实一个MyVector就是一个向量,当这个向量长度为2时,就是一个点啦。这样设计可以用来适配之后的应用,而不是只能计算点。

  1. /**
  2. * 聚类
  3. * Created by Mike on 12/9/2016.
  4. */
  5. public class Cluster {
  6. String tag = "";
  7. private ArrayList<MyVector> vectors = new ArrayList<>();
  8. private int eachVectorSize = 0;
  9. public ArrayList<MyVector> getVectors(){
  10. return vectors;
  11. }
  12. public void addToCluster(MyVector vector){
  13. this.vectors.add(vector);
  14. this.eachVectorSize = vector.size();
  15. }
  16. public void printCluster(boolean withTag){
  17. if (withTag){
  18. System.out.println(tag);
  19. }
  20. for (int i=0;i<vectors.size();i++){
  21. MyVector vector = vectors.get(i);
  22. System.out.print("Vector "+i+": ");
  23. vector.printVector();
  24. }
  25. }
  26. public void clearCluster(){
  27. this.vectors.clear();
  28. }
  29. /**
  30. * 计算簇的中点
  31. * @return
  32. */
  33. public MyVector getCenterVector(){
  34. //预先初始化定长Vector
  35. MyVector vector = new MyVector(eachVectorSize);
  36. for (MyVector tempVector:vectors){
  37. for (int i=0; i < tempVector.size();i++){
  38. Double original = vector.get(i);
  39. vector.set(i,original+tempVector.get(i));
  40. }
  41. }
  42. for (int i=0;i<vector.size();i++){
  43. vector.set(i,vector.get(i)/this.vectors.size());
  44. }
  45. return vector;
  46. }
  47. }

然后是用来实现K-Means算法的类:

K-Means类

这个类看起来长一点实际上也很好理解

centers储存的是所有的中心点,vectors是所有的点,clusters用来储存所有的聚类。

第一次生成中心点的时候我没有随机生成点,而是选择了几个先有的点。

调用startClustering()函数开始聚类,这个函数中的lastCenters是用来储存上一次的中心点,每次生成新的中心点后会与lastCenters作比较,如果中心点两次完全一样的话代表收敛了,可以直接结束程序,聚类已经完成了。

  1. import java.util.ArrayList;
  2. import java.util.List;
  3. /**
  4. * KMeans聚类算法
  5. * Created by Mike on 12/9/2016.
  6. */
  7. public class KMeans {
  8. private List<MyVector> vectors = new ArrayList<>();
  9. private List<MyVector> centers = new ArrayList<>(); //质心
  10. private int numberOfCluster = 0;
  11. private List<Cluster> clusters = new ArrayList<>(); //储存所有的族
  12. private int numberOfIteration = 100;
  13. public KMeans(List<MyVector> vectors,int numberOfCluster){
  14. this.vectors = vectors;
  15. this.numberOfCluster = numberOfCluster;
  16. initCenters();
  17. initClusters();
  18. }
  19. public List<Cluster> getClusters(){
  20. return clusters;
  21. }
  22. private void initClusters(){
  23. clusters.clear();
  24. //预先初始化所有的簇
  25. for (int i=0;i<numberOfCluster;i++){
  26. clusters.add(new Cluster());
  27. }
  28. }
  29. public void setNumberOfIteration(int numberOfIteration){
  30. if (numberOfIteration > 0){
  31. this.numberOfIteration = numberOfIteration;
  32. }else {
  33. System.out.println("numberOfIteration should be greater than 0");
  34. }
  35. }
  36. public void startClustering(){
  37. System.out.println("开始聚类");
  38. int counter = 0;
  39. List<MyVector> lastCenters = new ArrayList<>();
  40. boolean converged = false;
  41. while (!converged && counter < numberOfIteration){
  42. System.out.println("第"+counter+"次迭代");
  43. double[][] distanceMatrix = new double[vectors.size()][numberOfCluster];
  44. //生成距离矩阵
  45. for (int i=0;i<vectors.size();i++){
  46. for (int j=0;j<centers.size();j++){
  47. MyVector currentVector = vectors.get(i);
  48. MyVector centerVector = centers.get(j);
  49. double distance = Distance.getSimDistance(centerVector,currentVector);
  50. distanceMatrix[i][j] = distance;
  51. }
  52. }
  53. //add vectors to different clusters
  54. for (int i=0;i<distanceMatrix.length;i++){
  55. double[] centerDistance = distanceMatrix[i];
  56. MyVector vector = vectors.get(i);
  57. int index = getMinDistanceIndex(centerDistance);
  58. clusters.get(index).addToCluster(vector);
  59. }
  60. counter++;
  61. lastCenters.clear();
  62. lastCenters.addAll(centers);
  63. // printClusters();
  64. //Refresh centers
  65. for (int i=0;i<numberOfCluster;i++){
  66. MyVector vector = clusters.get(i).getCenterVector();
  67. centers.set(i,vector);
  68. }
  69. converged = isConverged(lastCenters);
  70. if (!converged && counter != numberOfIteration){
  71. initClusters();
  72. }
  73. }
  74. System.out.println("聚类完成\n迭代次数"+counter);
  75. // printClusters();
  76. }
  77. /**
  78. * 检测是否收敛(中心点和上次比是否变化)
  79. * @param lastCenters
  80. * @return
  81. */
  82. private boolean isConverged(List<MyVector> lastCenters){
  83. if (lastCenters.size() != numberOfCluster){
  84. return false;
  85. }
  86. boolean converge = true;
  87. for (int i=0;i<this.centers.size();i++){
  88. MyVector thisVector = this.centers.get(i);
  89. MyVector thatVector = lastCenters.get(i);
  90. if (!thisVector.isSameVector(thatVector)){
  91. converge = false;
  92. }
  93. }
  94. if (converge){
  95. System.out.println("检测到收敛");
  96. }
  97. return converge;
  98. }
  99. /**
  100. * 得到数组中最小值的下标
  101. * @param arr
  102. * @return
  103. */
  104. private int getMinDistanceIndex(double[] arr){
  105. double min = 99999999999999999999999.0;
  106. int index = 0;
  107. for (int i=0;i<arr.length;i++){
  108. if (arr[i] < min){
  109. index = i;
  110. min = arr[i];
  111. }
  112. }
  113. return index;
  114. }
  115. /**
  116. * 打印二维数组
  117. * @param mat
  118. */
  119. private void printMatrix(double[][] mat){
  120. int height = mat.length;
  121. int width = mat[0].length;
  122. for (int i= 0;i< height;i++){
  123. for (int j=0;j<width;j++){
  124. System.out.print(mat[i][j] + "\t");
  125. }
  126. System.out.println();
  127. }
  128. }
  129. /**
  130. * 打印中心点
  131. */
  132. private void printCenters(){
  133. System.out.println("Printing Centers");
  134. for (MyVector vector:centers){
  135. vector.printVector();
  136. }
  137. }
  138. /**
  139. * 打印所有聚类
  140. */
  141. private void printClusters(){
  142. for (Cluster cluster:clusters){
  143. cluster.printCluster(false);
  144. System.out.println();
  145. }
  146. }
  147. /**
  148. * 初始质心
  149. * 从所有的向量中挑选 numberOfCluster 个
  150. */
  151. private void initCenters(){
  152. int sizeOfVectors = vectors.size();
  153. for (int i=0;i<numberOfCluster;i++){
  154. int index= (int)(Math.random()*sizeOfVectors);
  155. centers.add(vectors.get(index));
  156. }
  157. }
  158. }

好了现在来试一试效果吧:

简单粗暴的加入这么多点,需要生成3组聚类;

接下来运行:

可能不太明显,但是还是能看出程序在不断迭代来逼近正确的分组。

  • 优点

优点嘛,速度快,而且整个算法很简单。

  • 缺点

说下这个算法的缺点吧,第一个当然就是 聚类个数要人为规定
,这是最大的一个缺点,因为在大数据中很多时候人们是不知道有多少个聚类的,所以就得乱猜。不过已经有K-Means+算法能解决这个问题。

第二个就是有可能会聚类失败,看上面的例子,一开始随机选择的点中,有两个是一组的,1个是另外一组的,如果一开始的三个点全都是一组的可能会造成聚类失败。

第三点是可能会受到单独点的干扰,比如有个点是(1000,1000)这个点是唯一的,没有点离他近,算法可能无法判断然后分到其他组中。

  • 应用

好了,写了这么多,是不是觉得给坐标点聚类太过于无聊,那咱们来玩个有聊的,给你一张图片,提取他的主色,这个应用也挺广泛的。比如iTunes11的专辑列表:

好了话不多说咱们开始思考怎么弄吧:

当然是聚类了,可是这次不再是二维坐标点了,使用每个像素颜色的R,G,B值生成一个三维向量,然后在计算每个向量的相似度作为距离比较,这里使用欧氏距离还是合适的,当然也可以使用其他度量标准,比如,余弦相似度:

其实就是高中的余弦定理公式,用来求两个向量夹角的,夹角越小当然向量相似度越高啦,上面的公式吧向量拓展到了n维度。在这里只是提一下。

所以这个应用完全可以使用刚才的代码,初始化MyVector的时候传入三个元素就代表一个三维向量了。

ThemeColorPicker类

这个类用来读取一个图片,然后读取每个像素的R,G,B值然后生成一个MyVector的列表,然后使用刚才设计的KMeans类进行聚类就行啦,完事之后我还写了个GUI来测试。

  1. package ColorPicker;
  2. import Algo.Cluster;
  3. import Algo.KMeans;
  4. import Algo.MyVector;
  5. import javax.imageio.ImageIO;
  6. import java.awt.image.BufferedImage;
  7. import java.io.File;
  8. import java.io.IOException;
  9. import java.util.ArrayList;
  10. import java.util.List;
  11. /**
  12. * Created by Mike on 12/9/2016.
  13. */
  14. public class ThemeColorPicker {
  15. private BufferedImage image;
  16. List<MyVector> vectors;
  17. private String path;
  18. public void setColorNumber(int colorNumber) {
  19. this.colorNumber = colorNumber;
  20. }
  21. private int colorNumber = 5;
  22. public ThemeColorPicker(String path){
  23. image = FileUtil.loadImg(path);
  24. vectors = generateColorVectors();
  25. this.path = path;
  26. }
  27. public void getThemeColor(){
  28. KMeans kMeans = new KMeans(vectors,colorNumber);
  29. kMeans.startClustering();
  30. List<Cluster> clusters = kMeans.getClusters();
  31. List<MyVector> vectors = new ArrayList<>();
  32. for (Cluster cluster:clusters){
  33. MyVector vector = cluster.getCenterVector();
  34. vectors.add(vector);
  35. System.out.println(rgbToHex(vector));
  36. }
  37. new PaletteGUI(vectors,path);
  38. }
  39. private String rgbToHex(MyVector vector){
  40. int r = vector.get(0).intValue();
  41. int g = vector.get(1).intValue();
  42. int b = vector.get(2).intValue();
  43. String hex = String.format("#%02x%02x%02x", r, g, b);
  44. return hex;
  45. }
  46. private List<MyVector> generateColorVectors(){
  47. List<MyVector> vectors = new ArrayList<>();
  48. for (int i = 0;i<image.getWidth();i++){
  49. for (int j=0;j<image.getHeight();j++){
  50. double r = getR(i,j);
  51. double g = getG(i,j);
  52. double b = getB(i,j);
  53. vectors.add(new MyVector(new double[]{r,g,b}));
  54. }
  55. }
  56. return vectors;
  57. }
  58. private double getR(int x,int y){
  59. int rgb = image.getRGB(x, y);
  60. int r = (rgb & 0xff0000) >> 16;
  61. return r;
  62. }
  63. private double getG(int x,int y){
  64. int rgb = image.getRGB(x, y);
  65. int g = (rgb & 0xff00) >> 8;
  66. return g;
  67. }
  68. private double getB(int x,int y){
  69. int rgb = image.getRGB(x, y);
  70. int b = (rgb & 0xff);
  71. return b;
  72. }
  73. public static void main(String[] args){
  74. ThemeColorPicker picker = new ThemeColorPicker("/Users/Mike/Desktop/1.jpg");
  75. picker.setColorNumber(5);
  76. picker.getThemeColor();
  77. }
  78. }
  79. class FileUtil{
  80. public static BufferedImage loadImg(String path){
  81. File imgFile = new File(path);
  82. try {
  83. return ImageIO.read(imgFile);
  84. } catch (IOException e) {
  85. e.printStackTrace();
  86. }
  87. return null;
  88. }
  89. }

运行以下代码:

设置生成5个主色

看看控制台输出:

迭代了30次收敛,当设置了需要更多聚类的时候迭代次数会增加的,而且迭代次数是不固定的。

看看效果:

所有的代码可以在我的Github中下载到:

https://github.com/Yigang0622/K-Means

我的博客MikeTech app现已登陆iPhone和Android

iPhone版下载
Android版下载

Comments

July 21, 2018 at 10:52 am

There are no comments

keyboard_arrow_up