如题,我写了一段 Spark MLlib KMeans 的代码,发现它能预测新数据的分类,但我还想打印原始(训练)数据中每条记录所属的分类,求解答。
package com.marstor.mllib
import org.apache.log4j.PropertyConfigurator
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}
/**
* Created by root on 5/24/17.
* 统计国足最近五年的战绩,看国足是属于亚洲几流水平
*/
object KMeansGuoZu {

  /**
   * Trains a K-means model on the training CSV, prints the cluster centers, and then prints
   * the cluster assigned to every training record and to every test record.
   *
   * Expected CSV layout (the header row is skipped; the leading label column is optional):
   * {{{
   * country,2006,2007,2010,2014,2015
   * Japan,28,4,9,29,5
   * South Korea,17,3,15,27,2
   * }}}
   *
   * @param args command-line arguments; not used
   */
  def main(args: Array[String]): Unit = {
    PropertyConfigurator.configure("log4j.properties")
    val trainingDataFilePath = "data/ml/guozu.csv"
    val testDataFilePath = "data/ml/guozu_train.csv"
    val numClusters = 3
    val numIterations = 20
    val conf = new SparkConf()
      .setAppName("Spark MLlib Exercise:KMeansGuoZu Clustering")
      .setMaster("local[3]")
    val sc = new SparkContext(conf)
    try {
      // The records are read twice (training, then reporting their clusters), so cache them.
      val trainingRecords = loadRecords(sc, trainingDataFilePath).cache()
      // K-means iterates over its input repeatedly; cache the bare feature vectors as well.
      val trainingVectors = trainingRecords.map { case (_, features) => features }.cache()

      // The overload taking `runs` is deprecated, ignored by Spark 2.x and removed in
      // Spark 3.0, so the three-argument form is used.
      val clusters: KMeansModel = KMeans.train(trainingVectors, numClusters, numIterations)
      println("Cluster Number:" + clusters.clusterCenters.length)
      println("Cluster Centers Information Overview:")
      clusters.clusterCenters.foreach(center => println(" " + center.toString))

      // Report which cluster each original (training) record was assigned to.
      println("Cluster assignment of the training data:")
      trainingRecords.collect().foreach { case (label, features) =>
        println(assignmentMessage(label, features.toString, clusters.predict(features)))
      }

      // Check which cluster each test record belongs to, based on the clustering result.
      loadRecords(sc, testDataFilePath).collect().foreach { case (label, features) =>
        println(assignmentMessage(label, features.toString, clusters.predict(features)))
      }
      println("Spark MLlib K-means clustering test finished.")
    } finally {
      // Always release the Spark context, even when loading or training fails.
      sc.stop()
    }
  }

  /**
   * Reads a CSV file and turns every data line into an (optional label, feature vector) pair.
   * Header lines and blank lines are skipped.
   *
   * @param sc   the active Spark context
   * @param path path of the CSV file to read
   * @return an RDD of (optional label, dense feature vector) pairs
   */
  private def loadRecords(sc: SparkContext, path: String) =
    sc.textFile(path)
      .filter(isDataLine _)
      .map(parseLine _)
      .map { case (label, values) => (label, Vectors.dense(values)) }

  /**
   * Splits one CSV data line into an optional leading label and its numeric values.
   * Empty fields are dropped. When the first remaining field is not a number it is taken
   * as the record's label (e.g. the country name); all other fields must be numeric.
   *
   * @param line a single non-header CSV line
   * @return the label, if the line starts with one, and the numeric values
   * @throws NumberFormatException if a field other than the leading label is not numeric
   */
  private def parseLine(line: String): (Option[String], Array[Double]) = {
    val fields = line.split(",").map(_.trim).filter(_.nonEmpty)
    val startsWithLabel =
      fields.headOption.exists(first => scala.util.Try(first.toDouble).isFailure)
    if (startsWithLabel) (Some(fields.head), fields.tail.map(_.toDouble))
    else (None, fields.map(_.toDouble))
  }

  /**
   * Builds the line printed for one clustered record. Without a label the text is
   * "The data <features> belongs to cluster <n>"; with a label, the label precedes the features.
   *
   * @param label    optional record label
   * @param features textual form of the record's feature vector
   * @param cluster  index of the cluster the record was assigned to
   * @return the message to print
   */
  private def assignmentMessage(label: Option[String], features: String, cluster: Int): String = {
    val described = label.fold(features)(name => s"$name $features")
    s"The data $described belongs to cluster $cluster"
  }

  /** Returns true for lines that carry a record: non-null, non-blank and not the header row. */
  private def isDataLine(line: String): Boolean =
    line != null && line.trim.nonEmpty && !isColumnNameLine(line)

  /**
   * Returns true when the line is the CSV header row, recognised by it containing the
   * column name "2006".
   */
  private def isColumnNameLine(line: String): Boolean =
    line != null && line.contains("2006")
}
|