[Java] K-means算法(Spark Demo) →→→→→进入此内容的聊天室

来自 , 2020-11-07, 写在 Java, 查看 161 次.
URL http://www.code666.cn/view/c0f168ce
  1. package spark.examples
  2.  
  3. import java.util.Random
  4. import spark.SparkContext
  5. import spark.SparkContext._
  6. import spark.examples.Vector._
  7.  
  8. object SparkKMeans {
  9.     /**
  10.      * line -> vector
  11.      */
  12. def parseVector (line: String) : Vector = {
  13.         return new Vector (line.split (' ').map (_.toDouble) )
  14.     }
  15.  
  16.     /**
  17.      * 计算该节点的最近中心节点
  18.      */
  19. def closestCenter (p: Vector, centers: Array[Vector]) : Int = {
  20.         var bestIndex = 0
  21.         var bestDist = p.squaredDist (centers (0) ) //差平方之和
  22.         for (i < - 1 until centers.length) {
  23.             val dist = p.squaredDist (centers (i) )
  24.             if (dist < bestDist) {
  25.                 bestDist = dist
  26.                 bestIndex = i
  27.             }
  28.         }
  29.         return bestIndex
  30.     }
  31.  
  32. def main (args: Array[String]) {
  33.         if (args.length < 3) {
  34.             System.err.println ("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
  35.             System.exit (1)
  36.         }
  37.         val sc = new SparkContext (args (0), "SparkKMeans")
  38.         val lines = sc.textFile (args (1), args (5).toInt)
  39.                     val points = lines.map (parseVector (_) ).cache() //文本中每行为一个节点,再将每个节点转换成Vector
  40.                                  val dimensions = args (2).toInt //节点的维度
  41.                                          val k = args (3).toInt //聚类个数
  42.                                                  val iterations = args (4).toInt //迭代次数
  43.  
  44.                                                          // 随机初始化k个中心节点
  45.                                                          val rand = new Random (42)
  46.         var centers = new Array[Vector] (k)
  47.         for (i < - 0 until k)
  48.             centers (i) = Vector (dimensions, _ => 2 * rand.nextDouble - 1)
  49.                           println ("Initial centers: " + centers.mkString (", ") )
  50.                           val time1 = System.currentTimeMillis()
  51.             for (i < - 1 to iterations) {
  52.                 println ("On iteration " + i)
  53.  
  54.                 // Map each point to the index of its closest center and a (point, 1) pair
  55.                 // that we will use to compute an average later
  56.                 val mappedPoints = points.map { p => (closestCenter (p, centers), (p, 1) ) }
  57.  
  58.                 val newCenters = mappedPoints.reduceByKey {
  59.                 case ( (sum1, count1), (sum2, count2) ) => (sum1 + sum2, count1 + count2) //(向量相加, 计数器相加)
  60.                     } .map {
  61.                 case (id, (sum, count) ) => (id, sum / count) //根据前面的聚类,重新计算中心节点的位置
  62.                     } .collect
  63.  
  64.                 // 更新中心节点
  65.                 for ( (id, value) < - newCenters) {
  66.                     centers (id) = value
  67.                 }
  68.             }
  69.                        val time2 = System.currentTimeMillis()
  70.                                    println ("Final centers: " + centers.mkString (", ") + ", time: " + (time2 - time1) )
  71.     }
  72. }

回复 "K-means算法(Spark Demo)"

这儿你可以回复上面这条便签

captcha