Spark MLlib中的随机森林(Random Forest)算法原理及实例(Scala/Java/python)

发布于:2021-07-27 01:46:16

随机森林分类器:


算法简介:


? ? ? ? 随机森林是决策树的集成算法。随机森林包含多个决策树来降低过拟合的风险。随机森林同样具有易解释性、可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。


? ? ? ?随机森林分别训练一系列的决策树,所以训练过程是并行的。因算法中加入随机过程,所以每个决策树又有少量区别。通过合并每个树的预测结果来减少预测的方差,提高在测试集上的性能表现。


? ? ? ?随机性体现:
1.每次迭代时,对原始数据进行二次抽样来获得不同的训练数据。


2.对于每个树节点,考虑不同的随机特征子集来进行分裂。


? ? ? ? 除此之外,决策时的训练过程和单独决策树训练过程相同。


? ? ? ? 对新实例进行预测时,随机森林需要整合其各个决策树的预测结果。回归和分类问题的整合的方式略有不同。分类问题采取投票制,每个决策树投票给一个类别,获得最多投票的类别为最终结果。回归问题每个树得到的预测结果为实数,最终的预测结果为各个树预测结果的*均值。


? ? ? ? spark.ml支持二分类、多分类以及回归的随机森林算法,适用于连续特征以及类别特征。


参数:


checkpointInterval:


类型:整数型。


含义:设置检查点间隔(>=1),或不设置检查点(-1)。


featureSubsetStrategy:


类型:字符串型。


含义:每次分裂候选特征数量。


featuresCol:


类型:字符串型。


含义:特征列名。


impurity:


类型:字符串型。


含义:计算信息增益的准则(不区分大小写)。


labelCol:


类型:字符串型。


含义:标签列名。


maxBins:


类型:整数型。


含义:连续特征离散化的最大数量,以及选择每个节点分裂特征的方式。


maxDepth:


类型:整数型。


含义:树的最大深度(>=0)。


minInfoGain:


类型:双精度型。


含义:分裂节点时所需最小信息增益。


minInstancesPerNode:


类型:整数型。


含义:分裂后自节点最少包含的实例数量。


numTrees:


类型:整数型。


含义:训练的树的数量。


predictionCol:


类型:字符串型。


含义:预测结果列名。


probabilityCol:


类型:字符串型。


含义:类别条件概率预测结果列名。


rawPredictionCol:


类型:字符串型。


含义:原始预测。


seed:


类型:长整型。


含义:随机种子。


subsamplingRate:


类型:双精度型。


含义:学*一棵决策树使用的训练数据比例,范围[0,1]。


thresholds:


类型:双精度数组型。


含义:多分类预测的阀值,以调整预测结果在各个类别的概率。


示例:


? ? ? ? 下面的例子导入LibSVM格式数据,并将之划分为训练数据和测试数据。使用第一部分数据进行训练,剩下数据来测试。训练之前我们使用了两种数据预处理方法来对特征进行转换,并且添加了元数据到DataFrame。


Scala:






[plain]?
view plain
?copy





  1. import?org.apache.spark.ml.Pipeline??
  2. import?org.apache.spark.ml.classification.{RandomForestClassificationModel,?RandomForestClassifier}??
  3. import?org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator??
  4. import?org.apache.spark.ml.feature.{IndexToString,?StringIndexer,?VectorIndexer}??
  5. ??
  6. //?Load?and?parse?the?data?file,?converting?it?to?a?DataFrame.??
  7. val?data?=?spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")??
  8. ??
  9. //?Index?labels,?adding?metadata?to?the?label?column.??
  10. //?Fit?on?whole?dataset?to?include?all?labels?in?index.??
  11. val?labelIndexer?=?new?StringIndexer()??
  12. ??.setInputCol("label")??
  13. ??.setOutputCol("indexedLabel")??
  14. ??.fit(data)??
  15. //?Automatically?identify?categorical?features,?and?index?them.??
  16. //?Set?maxCategories?so?features?with?>?4?distinct?values?are?treated?as?continuous.??
  17. val?featureIndexer?=?new?VectorIndexer()??
  18. ??.setInputCol("features")??
  19. ??.setOutputCol("indexedFeatures")??
  20. ??.setMaxCategories(4)??
  21. ??.fit(data)??
  22. ??
  23. //?Split?the?data?into?training?and?test?sets?(30%?held?out?for?testing).??
  24. val?Array(trainingData,?testData)?=?data.randomSplit(Array(0.7,?0.3))??
  25. ??
  26. //?Train?a?RandomForest?model.??
  27. val?rf?=?new?RandomForestClassifier()??
  28. ??.setLabelCol("indexedLabel")??
  29. ??.setFeaturesCol("indexedFeatures")??
  30. ??.setNumTrees(10)??
  31. ??
  32. //?Convert?indexed?labels?back?to?original?labels.??
  33. val?labelConverter?=?new?IndexToString()??
  34. ??.setInputCol("prediction")??
  35. ??.setOutputCol("predictedLabel")??
  36. ??.setLabels(labelIndexer.labels)??
  37. ??
  38. //?Chain?indexers?and?forest?in?a?Pipeline.??
  39. val?pipeline?=?new?Pipeline()??
  40. ??.setStages(Array(labelIndexer,?featureIndexer,?rf,?labelConverter))??
  41. ??
  42. //?Train?model.?This?also?runs?the?indexers.??
  43. val?model?=?pipeline.fit(trainingData)??
  44. ??
  45. //?Make?predictions.??
  46. val?predictions?=?model.transform(testData)??
  47. ??
  48. //?Select?example?rows?to?display.??
  49. predictions.select("predictedLabel",?"label",?"features").show(5)??
  50. ??
  51. //?Select?(prediction,?true?label)?and?compute?test?error.??
  52. val?evaluator?=?new?MulticlassClassificationEvaluator()??
  53. ??.setLabelCol("indexedLabel")??
  54. ??.setPredictionCol("prediction")??
  55. ??.setMetricName("accuracy")??
  56. val?accuracy?=?evaluator.evaluate(predictions)??
  57. println("Test?Error?=?"?+?(1.0?-?accuracy))??
  58. ??
  59. val?rfModel?=?model.stages(2).asInstanceOf[RandomForestClassificationModel]??
  60. println("Learned?classification?forest?model:
    "?+?rfModel.toDebugString)??


Java:






[java]?
view plain
?copy





  1. import?org.apache.spark.ml.Pipeline;??
  2. import?org.apache.spark.ml.PipelineModel;??
  3. import?org.apache.spark.ml.PipelineStage;??
  4. import?org.apache.spark.ml.classification.RandomForestClassificationModel;??
  5. import?org.apache.spark.ml.classification.RandomForestClassifier;??
  6. import?org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;??
  7. import?org.apache.spark.ml.feature.*;??
  8. import?org.apache.spark.sql.Dataset;??
  9. import?org.apache.spark.sql.Row;??
  10. import?org.apache.spark.sql.SparkSession;??
  11. ??
  12. //?Load?and?parse?the?data?file,?converting?it?to?a?DataFrame.??
  13. Dataset?data?=?spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");??
  14. ??
  15. //?Index?labels,?adding?metadata?to?the?label?column.??
  16. //?Fit?on?whole?dataset?to?include?all?labels?in?index.??
  17. StringIndexerModel?labelIndexer?=?new?StringIndexer()??
  18. ??.setInputCol("label")??
  19. ??.setOutputCol("indexedLabel")??
  20. ??.fit(data);??
  21. //?Automatically?identify?categorical?features,?and?index?them.??
  22. //?Set?maxCategories?so?features?with?>?4?distinct?values?are?treated?as?continuous.??
  23. VectorIndexerModel?featureIndexer?=?new?VectorIndexer()??
  24. ??.setInputCol("features")??
  25. ??.setOutputCol("indexedFeatures")??
  26. ??.setMaxCategories(4)??
  27. ??.fit(data);??
  28. ??
  29. //?Split?the?data?into?training?and?test?sets?(30%?held?out?for?testing)??
  30. Dataset[]?splits?=?data.randomSplit(new?double[]?{0.7,?0.3});??
  31. Dataset?trainingData?=?splits[0];??
  32. Dataset?testData?=?splits[1];??
  33. ??
  34. //?Train?a?RandomForest?model.??
  35. RandomForestClassifier?rf?=?new?RandomForestClassifier()??
  36. ??.setLabelCol("indexedLabel")??
  37. ??.setFeaturesCol("indexedFeatures");??
  38. ??
  39. //?Convert?indexed?labels?back?to?original?labels.??
  40. IndexToString?labelConverter?=?new?IndexToString()??
  41. ??.setInputCol("prediction")??
  42. ??.setOutputCol("predictedLabel")??
  43. ??.setLabels(labelIndexer.labels());??
  44. ??
  45. //?Chain?indexers?and?forest?in?a?Pipeline??
  46. Pipeline?pipeline?=?new?Pipeline()??
  47. ??.setStages(new?PipelineStage[]?{labelIndexer,?featureIndexer,?rf,?labelConverter});??
  48. ??
  49. //?Train?model.?This?also?runs?the?indexers.??
  50. PipelineModel?model?=?pipeline.fit(trainingData);??
  51. ??
  52. //?Make?predictions.??
  53. Dataset?predictions?=?model.transform(testData);??
  54. ??
  55. //?Select?example?rows?to?display.??
  56. predictions.select("predictedLabel",?"label",?"features").show(5);??
  57. ??
  58. //?Select?(prediction,?true?label)?and?compute?test?error??
  59. MulticlassClassificationEvaluator?evaluator?=?new?MulticlassClassificationEvaluator()??
  60. ??.setLabelCol("indexedLabel")??
  61. ??.setPredictionCol("prediction")??
  62. ??.setMetricName("accuracy");??
  63. double?accuracy?=?evaluator.evaluate(predictions);??
  64. System.out.println("Test?Error?=?"?+?(1.0?-?accuracy));??
  65. ??
  66. RandomForestClassificationModel?rfModel?=?(RandomForestClassificationModel)(model.stages()[2]);??
  67. System.out.println("Learned?classification?forest?model:
    "?+?rfModel.toDebugString());??


Python:






[python]?
view plain
?copy





  1. from?pyspark.ml?import?Pipeline??
  2. from?pyspark.ml.classification?import?RandomForestClassifier??
  3. from?pyspark.ml.feature?import?StringIndexer,?VectorIndexer??
  4. from?pyspark.ml.evaluation?import?MulticlassClassificationEvaluator??
  5. ??
  6. #?Load?and?parse?the?data?file,?converting?it?to?a?DataFrame.??
  7. data?=?spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")??
  8. ??
  9. #?Index?labels,?adding?metadata?to?the?label?column.??
  10. #?Fit?on?whole?dataset?to?include?all?labels?in?index.??
  11. labelIndexer?=?StringIndexer(inputCol="label",?outputCol="indexedLabel").fit(data)??
  12. #?Automatically?identify?categorical?features,?and?index?them.??
  13. #?Set?maxCategories?so?features?with?>?4?distinct?values?are?treated?as?continuous.??
  14. featureIndexer?=??
  15. ????VectorIndexer(inputCol="features",?outputCol="indexedFeatures",?maxCategories=4).fit(data)??
  16. ??
  17. #?Split?the?data?into?training?and?test?sets?(30%?held?out?for?testing)??
  18. (trainingData,?testData)?=?data.randomSplit([0.7,?0.3])??
  19. ??
  20. #?Train?a?RandomForest?model.??
  21. rf?=?RandomForestClassifier(labelCol="indexedLabel",?featuresCol="indexedFeatures",?numTrees=10)??
  22. ??
  23. #?Chain?indexers?and?forest?in?a?Pipeline??
  24. pipeline?=?Pipeline(stages=[labelIndexer,?featureIndexer,?rf])??
  25. ??
  26. #?Train?model.??This?also?runs?the?indexers.??
  27. model?

相关推荐

最新更新

猜你喜欢