Multi-class Text Classification with the Spark MLlib Random Forest Algorithm

1. Data preparation

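The data set consists of 200,000 manually labeled text records, one per line in the form label#k-v#text. Samples (pet-product listing titles, all with label 1):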

1#k-v#*亮亮爱宠*波波宠物指甲钳指甲剪附送锉刀适用小型犬及猫特价
1#k-v#*顺丰包邮*宠物药品圣马利诺PowerIgG免疫力球蛋白犬猫细小病毒
1#k-v#*包邮*法国罗斯蔓草本精华宠物浴液薰衣草护色润泽香波拍套餐
1#k-v#*包邮*家朵102宠物沐浴液
1#k-v#*包邮*家朵102宠物沐浴液猫
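Every later step splits a record on the literal delimiter #k-v#. A minimal plain-Java sketch of that split, assuming the label itself never contains the delimiter; RecordParser is a hypothetical helper, not part of the original code. Passing a limit of 2 keeps any later occurrence of #k-v# inside the text instead of silently dropping it.

```java
// Hypothetical helper: splits a "label#k-v#text" record into {label, text}.
final class RecordParser {
    // '#', 'k', '-' and 'v' are all literal characters in a Java regex here
    static final String DELIMITER = "#k-v#";

    private RecordParser() {}

    /** Returns {label, text}; throws IllegalArgumentException if the delimiter is missing. */
    static String[] parse(String line) {
        // limit 2: everything after the first delimiter belongs to the text
        String[] parts = line.split(DELIMITER, 2);
        if (parts.length < 2) {
            throw new IllegalArgumentException("missing " + DELIMITER + ": " + line);
        }
        return new String[] {parts[0].trim(), parts[1]};
    }
}
```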

2. Word segmentation

The text is segmented into words with the ansj library, and stopwords are removed. The code:

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.ansj.domain.Result;
import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.io.FileUtils;

public class Seg {
    // Stopword dictionary, one word per line
    private static Set<String> stopwords = new HashSet<String>();

    static {
        File f = new File(""); // stopword file (path omitted in the original)
        try {
            List<String> lines = FileUtils.readLines(f, StandardCharsets.UTF_8);
            stopwords.addAll(lines);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws IOException {
        File f = new File("");          // labeled input, "label#k-v#text" per line (path omitted in the original)
        File resultFile = new File(""); // segmented output (path omitted in the original)
        List<String> lists = FileUtils.readLines(f, StandardCharsets.UTF_8);
        int count = 0;
        for (String str : lists) {
            count++;
            String[] parts = str.split("#k-v#");
            String index = parts[0];
            // System.out.println(count + " " + Integer.parseInt(index));
            Result res = ToAnalysis.parse(parts[1]);
            List<Term> terms = res.getTerms();
            StringBuilder wordStr = new StringBuilder();
            for (Term t : terms) {
                String word = t.getName();
                // Keep words longer than one character that are not stopwords
                if (word.length() > 1 && !stopwords.contains(word)) {
                    wordStr.append(' ').append(word);
                }
            }
            // Write "label#k-v# word1 word2 ..." only if some word survived
            if (wordStr.length() > 0) {
                FileUtils.write(resultFile, index + "#k-v#" + wordStr + "\n", StandardCharsets.UTF_8, true);
            }
            System.out.println(count);
        }
    }
}
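The filtering rule in the segmentation loop keeps a token only when it is longer than one character and is not a stopword. To make that rule testable without ansj, the sketch below takes an already-segmented token list (a stand-in, assumed here, for the names of the Terms returned by the segmenter) and builds the same output line. SegFilter and toOutputLine are hypothetical names.

```java
import java.util.List;
import java.util.Set;

// Hypothetical helper mirroring the filter in Seg; tokens are assumed to be
// the segmenter's output, already split into words.
final class SegFilter {
    private SegFilter() {}

    /** Returns "label#k-v# w1 w2 ..." or null when no token survives. */
    static String toOutputLine(String label, List<String> tokens, Set<String> stopwords) {
        StringBuilder words = new StringBuilder();
        for (String word : tokens) {
            // Drop single-character tokens and stopwords
            if (word.length() > 1 && !stopwords.contains(word)) {
                words.append(' ').append(word);
            }
        }
        return words.length() == 0 ? null : label + "#k-v#" + words;
    }
}
```

The leading space after #k-v# matches what the original loop writes; the TF-IDF step trims it again.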

3. TF-IDF transformation of the segmented data

For this step I use the TF-IDF feature transformers that ship with Spark MLlib (HashingTF and IDF). The code:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}

// case class FileRecord(index: Int, seg: String)
object TfIdf {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TfIdfExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Two string columns: the label and the space-separated segmented words
    val schemaString = "index seg"
    val fields = schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, nullable = true))
    val schema = StructType(fields)

    // Each input line is "label#k-v# word1 word2 ..."
    val srcRDD = sc.textFile("/tmp/", 1)
      .map(x => x.split("#k-v#"))
      .map(attributes => Row(attributes(0), attributes(1).trim))
    val sentenceData = sqlContext.createDataFrame(srcRDD, schema).toDF("label", "seg")

    // Split the segmented text on whitespace into a words array
    val tokenizer = new Tokenizer().setInputCol("seg").setOutputCol("words")
    val wordsData = tokenizer.transform(sentenceData)

    // Hash words into term-frequency vectors. Only 26 buckets are used, so many
    // different words share a bucket (the default is 2^18 = 262144).
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(26)
    val featurizedData = hashingTF.transform(wordsData)

    // Fit IDF on the corpus, then rescale the term frequencies into TF-IDF weights
    val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    val idfModel = idf.fit(featurizedData)
    val rescaledData = idfModel.transform(featurizedData)

    rescaledData.select("features", "label").take(3).foreach(println)
    rescaledData.select("features", "label").write.format("json").save("/del")
  }
}
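HashingTF maps each word to a bucket, hash(word) mod numFeatures, and counts occurrences; IDF then multiplies each count by log((m + 1) / (df + 1)), where m is the number of documents and df the number of documents in which the bucket is non-zero (the smoothed formula given in the Spark MLlib documentation). The plain-Java sketch below reproduces that arithmetic under one loudly stated assumption: it buckets with Java's String.hashCode(), whereas Spark's hash function differs between versions (ml.feature.HashingTF in Spark 2.x uses MurmurHash3), so its bucket numbers may not match Spark's. TfIdfSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch of hashed TF followed by Spark-style smoothed IDF.
// ASSUMPTION: String.hashCode() stands in for Spark's hash function.
final class TfIdfSketch {
    private TfIdfSketch() {}

    /** Non-negative bucket index in [0, numFeatures). */
    static int bucket(String term, int numFeatures) {
        return Math.floorMod(term.hashCode(), numFeatures);
    }

    /** Raw term counts per bucket, keys in ascending order. */
    static TreeMap<Integer, Double> termFrequencies(List<String> words, int numFeatures) {
        TreeMap<Integer, Double> tf = new TreeMap<>();
        for (String w : words) {
            tf.merge(bucket(w, numFeatures), 1.0, Double::sum);
        }
        return tf;
    }

    /** Smoothed IDF with the natural log: log((m + 1) / (df + 1)). */
    static double idf(long numDocs, long docFreq) {
        return Math.log((numDocs + 1.0) / (docFreq + 1.0));
    }

    /** TF-IDF vectors (bucket -> weight), one per document. */
    static List<TreeMap<Integer, Double>> tfIdf(List<List<String>> docs, int numFeatures) {
        List<TreeMap<Integer, Double>> tfs = new ArrayList<>();
        Map<Integer, Long> docFreq = new HashMap<>();
        for (List<String> doc : docs) {
            TreeMap<Integer, Double> tf = termFrequencies(doc, numFeatures);
            tfs.add(tf);
            // A bucket counts at most once per document towards df
            for (Integer b : tf.keySet()) {
                docFreq.merge(b, 1L, Long::sum);
            }
        }
        List<TreeMap<Integer, Double>> result = new ArrayList<>();
        for (TreeMap<Integer, Double> tf : tfs) {
            TreeMap<Integer, Double> weighted = new TreeMap<>();
            for (Map.Entry<Integer, Double> e : tf.entrySet()) {
                weighted.put(e.getKey(), e.getValue() * idf(docs.size(), docFreq.get(e.getKey())));
            }
            result.add(weighted);
        }
        return result;
    }
}
```

Because the weight is count × idf, the same bucket's weight scales linearly with its count. In the sample output that follows, bucket 7 has 2.037399707294254 in the first record and exactly twice that, 4.074799414588508, in the second, which is what counts of 1 and 2 for that bucket would give.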

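The result is written as JSON. In each record, features is a sparse vector (type 0) given by its size, the indices of its non-zero entries and their values, and label is the class label as a string. Sample records: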

{"features":{"type":0,"size":26,"indices":[0,5,6,7,9,10,14,17,21],"values":[2.028990788466258,1.8600672974067514,1.8464729103095205,2.037399707294254,1.908861495143531,3.6260607728633083,2.0363086347259687,1.8261747092361593,2.0640809711702492]},"label":"1"}
{"features":{"type":0,"size":26,"indices":[7,8,17],"values":[4.074799414588508,2.1216332358971366,1.8261747092361593]},"label":"1"}

4. Converting the JSON data to LIBSVM format

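The random forest code in the next step loads its training data with Spark's LIBSVM reader, so the JSON output is converted to LIBSVM format (label index1:value1 index2:value2 ..., with indices in ascending order). The code: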

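import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.io.FileUtils;
import org.json.JSONArray;
import org.json.JSONObject;

public class JsonToLibsvm {
    public static void main(String[] args) throws IOException {
        File f = new File("D:/sogouOutput/json_feature");
        File libsvmFile = new File("D:/sogouOutput/libsvm_feature");
        List<String> features = FileUtils.readLines(f, StandardCharsets.UTF_8);
        for (String str : features) {
            JSONObject obj = new JSONObject(str);
            String label = obj.getString("label");
            JSONArray indexArr = obj.getJSONObject("features").getJSONArray("indices");
            JSONArray valueArr = obj.getJSONObject("features").getJSONArray("values");
            int length = indexArr.length();
            // TreeMap keeps the feature indices in ascending order, as LIBSVM requires
            Map<Integer, Double> indiceAndValue = new TreeMap<Integer, Double>();
            for (int i = 0; i < length; i++) {
                indiceAndValue.put(indexArr.getInt(i), valueArr.getDouble(i));
            }
            // Spark's LIBSVM reader rejects feature index 0 (indices must be one-based),
            // so the entry at index 0 is dropped. This discards that feature's value;
            // adding 1 to every index instead would keep it.
            if (indiceAndValue.containsKey(0)) {
                indiceAndValue.remove(0);
            }
            StringBuilder line = new StringBuilder(label);
            for (Map.Entry<Integer, Double> m : indiceAndValue.entrySet()) {
                line.append(' ').append(m.getKey()).append(':').append(m.getValue());
            }
            // System.out.println(line);
            FileUtils.write(libsvmFile, line + "\n", StandardCharsets.UTF_8, true);
        }
    }
}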
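An alternative to deleting index 0 is to add 1 to every zero-based index when writing the file. Spark's LIBSVM reader subtracts 1 again when loading, so the indices it reads back are the same bucket numbers HashingTF produced and no feature is lost. A minimal sketch of formatting one record that way, assuming the indices and values have already been read from the JSON indices/values arrays; LibsvmLine is a hypothetical helper, not the author's code.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper: formats one sparse vector as a LIBSVM line,
// converting Spark's zero-based indices to LIBSVM's one-based indices.
final class LibsvmLine {
    private LibsvmLine() {}

    static String format(String label, int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values differ in length");
        }
        // TreeMap sorts the indices ascending, as the LIBSVM reader requires
        TreeMap<Integer, Double> sorted = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            sorted.put(indices[i] + 1, values[i]); // zero-based -> one-based
        }
        StringBuilder line = new StringBuilder(label);
        for (Map.Entry<Integer, Double> e : sorted.entrySet()) {
            line.append(' ').append(e.getKey()).append(':').append(e.getValue());
        }
        return line.toString();
    }
}
```

Applied to the second JSON record shown earlier (indices 7, 8, 17), this would write indices 8, 9 and 18.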

Sample output:

1 7:2.037399707294254 
1 1:1.6033119355738932 5:1.8600672974067514 7:4.074799414588508 10:1.8130303864316542 13:2.0344821501999344 15:2.2043195316439834 18:2.0104112775954426 20:2.0108489143639154 25:1.9189925465072746 
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:4.074799414588508 17:1.8261747092361593 
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:2.037399707294254 13:2.0344821501999344 17:1.8261747092361593 20:2.0108489143639154 
1 7:2.037399707294254 

5. Classification

The classification code, based on Spark's RandomForestClassifierExample:

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

object RandomForestClassifierExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RandomForestClassifierExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Load and parse the data file, converting it to a DataFrame.
    // Feature indices in the file must be one-based and in ascending order.
    val data = sqlContext.read.format("libsvm").load("/tmp/libsvm_feature")

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)

    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 26 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(26).fit(data)

    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a RandomForest model.
    val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10)

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    // "precision" is the Spark 1.x metric name; Spark 2.0+ calls it "accuracy".
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println("Learned classification forest model:\n" + rfModel.toDebugString)

    sc.stop()
  }
}
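The evaluator compares prediction with indexedLabel row by row. For multiclass data, the overall precision metric of Spark 1.x is the fraction of rows predicted correctly (the quantity Spark 2.x names accuracy), and the code prints one minus that value as the test error. A plain-Java sketch of that arithmetic, with TestError as a hypothetical name and the label/prediction pairs passed as arrays instead of a DataFrame:

```java
// Hypothetical helper: overall accuracy and test error from label/prediction pairs.
final class TestError {
    private TestError() {}

    /** Fraction of positions where the prediction equals the label. */
    static double accuracy(double[] labels, double[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("need equally sized, non-empty arrays");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    /** Test error as printed by the Spark code: 1 - accuracy. */
    static double testError(double[] labels, double[] predictions) {
        return 1.0 - accuracy(labels, predictions);
    }
}
```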

When this code runs, the line val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data) fails with:

Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order

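The cause is that feature indices must not be 0. LIBSVM indices are one-based, and Spark's parser subtracts 1 from each index to make it zero-based. An index of 0 therefore becomes -1, which fails the check current > previous because previous starts at -1. The relevant part of Spark's source (MLUtils.parseLibSVMRecord):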

private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
  val items = line.split(' ')
  val label = items.head.toDouble
  val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
    val indexAndValue = item.split(':')
    val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
    val value = indexAndValue(1).toDouble
    (index, value)
  }.unzip

  // check if indices are one-based and in ascending order
  var previous = -1
  var i = 0
  val indicesLength = indices.length
  while (i < indicesLength) {
    val current = indices(i)
    require(current > previous, s"indices should be one-based and in ascending order;"
      + s""" found current=$current, previous=$previous; line="$line"""")
    previous = current
    i += 1
  }
  (label, indices.toArray, values.toArray)
}
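To check a converted file before handing it to Spark, the same rule can be replayed locally. The sketch below is a plain-Java port of the index check in parseLibSVMRecord above (subtract 1, then require strictly increasing indices starting from a previous value of -1); LibsvmCheck is a hypothetical name, and it only validates lines, it does not build vectors.

```java
// Hypothetical port of the index check in Spark's parseLibSVMRecord.
final class LibsvmCheck {
    private LibsvmCheck() {}

    /** Throws IllegalArgumentException unless indices are one-based and strictly ascending. */
    static void validate(String line) {
        String[] items = line.split(" ");
        Double.parseDouble(items[0]); // the label must be numeric
        int previous = -1;
        for (int i = 1; i < items.length; i++) {
            if (items[i].isEmpty()) {
                continue; // skip empty items from repeated spaces, like filter(_.nonEmpty)
            }
            String[] indexAndValue = items[i].split(":");
            if (indexAndValue.length != 2) {
                throw new IllegalArgumentException("malformed item \"" + items[i] + "\"");
            }
            int current = Integer.parseInt(indexAndValue[0]) - 1; // one-based -> zero-based
            Double.parseDouble(indexAndValue[1]);
            if (current <= previous) {
                throw new IllegalArgumentException("indices should be one-based and in ascending order;"
                        + " found current=" + current + ", previous=" + previous + "; line=\"" + line + "\"");
            }
            previous = current;
        }
    }
}
```

For example, 1 0:1.0 5:2.0 is rejected because index 0 becomes -1, which is not greater than the initial previous value of -1, while the sample lines from step 4 pass.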

Published: 2024-01-29 12:02:41

Article link: https://www.4u4v.net/it/170650096415140.html
