一个简单的spark贝叶斯分类程序

2018-08-21  本文已影响0人  snow_14b5

在笔记本跑了一个简单的贝叶斯分类示例,工程级的代码原理类似,只不过有些细节需要修改。

主要代码如下:

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.ml.feature.{HashingTF, }
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.

object bayes {
def main(args: Array[String]) {
val spark = SparkSession
.builder
.appName("bayes")
.getOrCreate()

import spark.implicits._

val sentenceDataFrame = spark.createDataFrame(Seq(    //比较简单的样本数据 0分类 水果; 1分类 粮食
  (0,"水果","苹果 橘子 香蕉"),
  (1, "粮食","大米 小米 土豆")
)).toDF("label","category", "text")

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
var wordData = tokenizer.transform(sentenceDataFrame)

val stopwordFile: String = "/applications/stopWords"        //引入停用词

val customizedStopWords: Array[String] = if (stopwordFile.isEmpty()) {
  Array.empty[String]
} else {
  val stopWordText = spark.read.text(stopwordFile).as[String].collect()
  stopWordText.flatMap(_.stripMargin.split("\\s+"))
}

val stopWordsRemover = new StopWordsRemover()
  .setInputCol("words")
  .setOutputCol("token")
stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
var wordDataWithOutStopWord = stopWordsRemover.transform(wordData)

var hashingTF = new HashingTF()
  .setInputCol("token").setOutputCol("tf")
val tf= hashingTF.transform(wordDataWithOutStopWord)
tf.cache()
tf.show(false)

val idf=new IDF().setInputCol("tf").setOutputCol("features").fit(tf)  //根据以上数据训练的idf模型,实际需要根据大量数据训练
val tfidf =idf.transform(tf)
tfidf.show(false)


val naiveBayesModel = new NaiveBayes()  //创建贝叶斯模型,用上面数据训练
  .setSmoothing(1)
  .fit(tfidf)

val training = spark.createDataFrame(List(  //待预测的测试数据
  (0, "大米")
)).toDF("id", "text")


var tokenfeature = tokenizer.transform(training)
wordDataWithOutStopWord = stopWordsRemover.transform(tokenfeature)

var trainRescaledData = hashingTF.transform(wordDataWithOutStopWord)
val tfidf1 = idf.transform(trainRescaledData)
val predictions = naiveBayesModel
  .transform(tfidf1)
predictions.printSchema()
val predict = predictions.first().getAs[Double]("prediction")   //预测结果 输出label 为1 粮食分类

println("predict aaaaa:")
println(predict)


spark.stop()

}

}

上一篇 下一篇

猜你喜欢

热点阅读