分享

使用机器学习算法和大数据工具集来预测已知的心脏疾病(附源码)

本帖最后由 howtodown 于 2016-11-24 09:58 编辑
问题导读:

1. 使用什么样的数据集来进行预测?
2. 使用哪些工具与技术去实现预测系统?
3. 如何对预测系统进行设计架构?

4. 如何对训练集数据进行分析?

5. 如何构造模型对象并进行实际预测?
6. 对于这个预测系统有什么缺点?




解决方案:

大数据和机器学习的组合是一项革命性的技术,如果以恰当的方式使用它,它可以在任何工业上产生影响。在医疗保健领域,它在很多情况下都有重要的使用,例如疾病检测、找到流行病早期爆发的迹象、使用集群来找到瘟疫流行的地区(例如寨卡(zika)易发区),或者在空气污染严重的国家找到空气质量最好的地带。在这篇文章里,我尝试用标准的机器学习算法和像 Apache Spark、parquet、Spark mllib和Spark SQL这样的大数据工具集,来探索已知的心脏疾病的预测。

源代码

这篇文章的源代码可以在GitHub的这里找到。此外,你可以从这里check out出整个eclipse项目。


使用的数据集

心脏疾病数据集是一个已经被机器学习研究人员深入研究过的数据集,它可以在UCI机器学习数据集仓库的这里免费获取。在这里有4个数据集,我已经使用了有14个主要特点的克利夫兰的数据集。这个数据集的的功能或属性如下:

  • age- 用年数表示的年龄
  • sex- 性别枚举(1 = 男性; 0 = 女性)
  • cp: 胸部疼痛的类型
      值为 '1': 典型的心绞痛
      值为 '2': 非典型的心绞痛
      值为 '3': 非心绞痛的疼痛
      值为 '4': 无临床症状
  • trestbpss: 静息血压 (准许入院的毫米汞柱(mm Hg))
  • chol:  以mg/dl为单位的血清类固醇
  • fbs: (空腹血糖 > 120 mg/dl) (1 = 是; 0 = 否)
  • restecg: 静息心电图结果

      值为 0: 正常

      值为 1: 有ST-T波异常 (T波倒置和/或ST段抬高或压低>0.05 mV)
      值为 2: 显示该标准下可能或明确的的左心室肥厚

  • thalach : 达到的最大心率
  • exang : 是否运动诱发的心绞痛 (1 = 是; 0 = 否)
  • oldpeak : 由与相对休息有明显差异的运动诱导的ST段压低
  • slope : 运动高峰期的ST段斜率
      值为 1: 上斜
      值为 2: 水平
  • ca : 被透视荧光检查(flourosopy)标注颜色的大血管的数量 (0-3)
  • thal : 3 = 正常; 6 = 固有缺陷; 7 = 可修复缺陷
  • num : 心脏病的诊断 (冠状动脉疾病状态)
      值为 0: < 50% 直径缩小  (意味着'没有疾病')
      值为 1: > 50% 直径缩小  (意味着'出现了疾病')


使用的技术

  • Apache Spark: Apache Spark是大数据栈的其中一个工具集,它是老技术map reduce的老大哥。相比于mapreduce,它在性能上要快得多,而且也更容易撰写代码。很多开发者常用的 RDD(弹性分布式数据集)是整个Apache Spark 块的缺陷所在,但在幕后,它很好的处理了所有的分布式计算工作。Spark配备了其他像Spark streaming、 Spark sql(在这篇文章中我用它来分析数据集)、spark mllib (我用它来应用机器学习片)这样很强大的组件包。从Spark官网能获取到的Spark的文档都非常出色,你可以在这里找到它们。
  • Spark SQL: Spark的类SQL API,支持数据帧 (和Python的Pandas library几乎相同,但它运行在一个完整的分布式数据集,因此并不所有功能类似)。
  • Parquet: Parquet是列式文件格式。原始数据文件用parquet格式被解析和存储。这大大加快了聚合查询的速度。一个列式存储格式在只获取需要的列的数据时大有帮助,也因此大大减少磁盘I / O消耗。
  • Spark MLLib: Spark的机器学习库。该库中的算法都是被优化过,能够分布式数据集上运行的算法。这是这个库和像SciKit那样在单进程上运行的其他流行的库的主要区别。
  • HDFS : 用于存储原始文件,存储生成的模型并存储结果。



设计

模型生成和存储层


屏幕快照 2016-11-23 上午10.15.52.jpg

如上图所示,原始文件要么被HDFS获取,要么被程序导入到HDFS。该文件或数据也可以通过Kafka的topics接收和使用spark streaming读取。对于本文和在GitHub上的示例代码的例子,我假设原文件驻留在HDFS。

这些文件通过用Java(也可以是python或scala )编写的Spark程序读取。

这些文件包含必须被转换为模型所需要的格式的数据。该模型需要的全是数字。 一些为空或没有值的数据点会被一个大的值,如“99”,取代。这种取代没有特定的意义,它只帮助我们通过数据的非空校验。同样的,最后的“num”参数基于用户是否有心脏病转换为数字“1”或“0”。因此在最后的“num”字段中,大于“1”的任何值会被转换为“1”,这意味着心脏病的存在。

数据文件现在被读到RDD去了。

对于这个数据集,我使用了朴素贝叶斯算法(这个算法在垃圾邮件过滤器中被使用)。利用机器学习库Spark (mllib),算法现在在被数据集中的数据训练。请注意:决策树算法在这个例子中可能也能给出很好的结果。

算法训练后,模型被存储到了hdfs额外的存储空间,用于在将来对测试数据进行预测。

下面是上面描述的行为的一段代码的截取:


[mw_shl_code=java,true]SparkConfAndCtxBuilder ctxBuilder = new SparkConfAndCtxBuilder();
JavaSparkContext jctx = ctxBuilder.loadSimpleSparkContext("Heart Disease Detection App", "local");

//读取数据到RDD,数据是逐行分割的字符串格式
JavaRDD<String> dsLines = jctx.textFile(trainDataLoc);
        // 使用适配器类解析每个文本行
        // 现在数据已经被转换成模型需要的格式了
JavaRDD<LabeledPoint> _modelTrainData = dsLines.map(new DataToModelAdapterMapper());

    //我们需要的模型在被数据训练
        //你可以替代下面的代码,来尝试使用决策树模型,并比较返回数据的精度
NaiveBayesModel _model = NaiveBayes.train(_modelTrainData.rdd());
_model.save(jctx.sc(), modelStorageLoc);

ctxBuilder.closeCtx();[/mw_shl_code]

应用到上面每个数据行的mapper类的截取代码如下:

[mw_shl_code=java,true]public LabeledPoint call(String dataRow) throws Exception {
    //用一个很大的值替代空的数据点,来避免不必要的空数据点
    String newLine = dataRow.replaceAll("\\?", "99.0");
    String[] tokens = newLine.split(",");

    System.out.println("tokens count : " + tokens.length);
    // 上一个token有被训练模型使用的实际预测值
    Integer lastToken = Integer.parseInt(tokens[13]);

    double[] featuresDblArr = new double[13];
    for(int i = 0; i < 13; i++) {
        featuresDblArr = Double.parseDouble(tokens);
    }
    // 构建特征向量
    Vector featuresVector = new DenseVector(featuresDblArr);

    Double classValue = 0.0;
    if(lastToken.intValue() > 0 ) classValue = 1.0;

    LabeledPoint _lp = new LabeledPoint(classValue, featuresVector);

    return _lp;
}
[/mw_shl_code]


数据分析层

此层用于训练数据的分析,这些分析数据可以用于查询像患者的最小年龄、所有患病女性和患病男性的总数对比的情况。这些查询的参数几乎总是在疾病出现的,或虽然没有病但出现了症状的人的情况下出现。

要在训练数据上运行数据分析,首先,要加载完整的数据(被清除了空值的数据)到rdd使用的一个文本文件。

然后用parquet格式保存这个rdd文本文件到额外存储空间。

从另一个程序加载数据到这个parquet存储空间的数据帧。

点击这里你可以看到下面这段截取代码的完整源码。


[mw_shl_code=java,true]String schemaString = "age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num";
List<StructField> fields = new ArrayList<>();
for (String fieldName : schemaString.split(" ")) {
    fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true));
}

StructType schema = DataTypes.createStructType(fields);
JavaRDD<Row> rowRdd = rows.map(new Function<String, Row>() {

    @Override
    public Row call(String record) throws Exception {
        String[] fields = record.split(",");
        return RowFactory.create(fields[0],fields[1],fields[2],fields[3],fields[4],fields[5],fields[6],fields[7],fields[8],fields[9],fields[10],fields[11],fields[12],fields[13]);
    }
});

DataFrame df = sqlCtx.createDataFrame(rowRdd, schema);
df.registerTempTable("heartDisData");

DataFrame results = sqlCtx.sql("select min(age) from heartDisData");

JavaRDD<String> jrdd = results.javaRDD().map(new Function<Row, String>() {

    @Override
    public String call(Row arg0) throws Exception {
        return arg0.toString();
    }
});

List<String> rstList = jrdd.collect();

for (String rStr : rstList) {
    System.out.println(" Minimum Age : " + rStr);
}[/mw_shl_code]

疾病预测层(参考Github的代码)

屏幕快照 2016-11-23 上午10.24.55.jpg

现在,使用Apache Spark加载测试数据到一个RDD。
对测试数据做模型适配和清除。
使用spark mllib从存储空间加载模型。
使用模型对象来预测疾病的出现。例如:

[mw_shl_code=java,true]NaiveBayesModel _model = NaiveBayesModel.load(<Spark Context>, <Model Storage Location>);[/mw_shl_code]
代码的截取如下所示,你可以点击这里看Github上的完整源码。

[mw_shl_code=java,true]SparkConfAndCtxBuilder ctxBuilder = new SparkConfAndCtxBuilder();
JavaSparkContext jctx = ctxBuilder.loadSimpleSparkContext("Heart Disease Detection App", "local");

JavaRDD<String> dsLines = jctx.textFile(testDataLoc);
JavaRDD<Vector> fRdd = dsLines.map(new TestDataToFeatureVectorMapper());

NaiveBayesModel _model = NaiveBayesModel.load(jctx.sc(), modelStorageLoc);

JavaRDD<Double> predictedResults = _model.predict(fRdd);
List<Double> prl = predictedResults.collect();
for (Double pr : prl) {
    System.out.println("Predicted Value : " + pr);
}[/mw_shl_code]

上述设计的问题

任何疾病预测系统的最重要的问题是准确度。一个错误的阴性的结果可能是一个危险的预测,它可能导致一种疾病被忽视。

深度学习已经发展到能够比普通机器学习算法提供更好的预测。在之后,我将尝试探索通过深度学习神经网络做同样的疾病预测。

总结

使用像 Apache Spark这样的工具和它的机器学习库,我们能够轻易地加载到一个心脏病数据集(从UCI),并训练常规机器学习模型。这个模型稍后会在测试数据上运行,用来预测心脏疾病的出现。


来源:可译网
作者:Rajat Mehta

已有(3)人评论

跳转到指定楼层
、一毛 发表于 2016-11-24 09:15:03
where is the github url ?
回复

使用道具 举报

howtodown 发表于 2016-11-24 09:57:21
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

推荐上一条 /2 下一条