
K-means: how to print the cluster assignments of the original data

pandatyut posted on 2017-6-2 17:25:54
As the title says: I wrote the Spark KMeans code below. It can predict the cluster for new data, but I also want to print the cluster assignment of the original (training) data. Any pointers would be appreciated.

package com.marstor.mllib

import org.apache.log4j.PropertyConfigurator
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

/**
  * Created by root on 5/24/17.
  * Cluster the Chinese national football team's results over the last five years
  * to see which tier of Asian football the team belongs to.
  */
object KMeansGuoZu {

  def main(args: Array[String]): Unit = {

    PropertyConfigurator.configure("log4j.properties")

    val trainingDataFilePath = "data/ml/guozu.csv"
    val testDataFilePath = "data/ml/guozu_train.csv"
    val numClusters = 3
    val numIterations = 20
    val runTimes = 2

    val conf = new SparkConf()
      .setAppName("Spark MLlib Exercise:KMeansGuoZu Clustering")
      .setMaster("local[3]")

    val sc = new SparkContext(conf)

    /**
      * Input CSV format: a header line followed by one row per country, e.g.
      * country,2006,2007,2010,2014,2015
      * 日本,28,4,9,29,5
      * 韩国,17,3,15,27,2
      */

    val rawTrainingData = sc.textFile(trainingDataFilePath)
    val parsedTrainingData =
      rawTrainingData.filter(!isColumnNameLine(_)).map(line => {
        Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))
      }).cache()

    val clusters: KMeansModel =
      KMeans.train(parsedTrainingData, numClusters, numIterations, runTimes)

    println("Cluster Number:" + clusters.clusterCenters.length)

    println("Cluster Centers Information Overview:")
    for (c <- clusters.clusterCenters) {
      println(" " + c.toString)
    }

    // Check which cluster each test data point belongs to, based on the clustering result.

    val rawTestData = sc.textFile(testDataFilePath)
    val parsedTestData = rawTestData.filter(!isColumnNameLine(_)).map(line => {
      Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))
    })

    parsedTestData.collect().foreach(testDataLine => {
      val predictedClusterIndex: Int = clusters.predict(testDataLine)

      println("The data " + testDataLine.toString + " belongs to cluster " +
        predictedClusterIndex)
    })

    println("Spark MLlib K-means clustering test finished.")

  }

  private def isColumnNameLine(line: String): Boolean =
    line != null && line.contains("2006")


}


Replies (1)

einhep posted on 2017-6-2 18:10:28
I'm also just learning this, but aren't the original data's clusters simply whatever your own trained model assigns? If you iterate over the training data, just print the assignment for each row as you go.
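
To make that concrete, here is a minimal sketch of how the training data's assignments can be printed, reusing the clusters model and the parsedTrainingData RDD from the code in the original post (the variable names and print format are taken from that code; everything else is only an illustration, not the one required way to do it):

// Sketch: print the cluster assignment of every original (training) row.
// Assumes the clusters model and parsedTrainingData RDD defined above.

// Option 1: mirror the test-data loop and predict each vector on the driver.
parsedTrainingData.collect().foreach(trainingVector => {
  val clusterIndex: Int = clusters.predict(trainingVector)
  println("Original data " + trainingVector.toString + " belongs to cluster " + clusterIndex)
})

// Option 2: compute the (vector, clusterIndex) pairs as an RDD first, so the
// prediction stays distributed, and collect only the results for printing.
val trainingAssignments = parsedTrainingData.map(v => (v, clusters.predict(v)))
trainingAssignments.collect().foreach { case (v, clusterIndex) =>
  println("Original data " + v.toString + " belongs to cluster " + clusterIndex)
}

Both options print the same thing; option 2 is just the more typical Spark pattern when the training set is large.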