naive bayes识别垃圾短信
本文对brett的机器学习与R语言(Machine Learning with R)一书中的垃圾短信识别的笔记。在brett的书中,介绍了如何通过naive bayes对短信进行训练,并预测短信是否为垃圾短信。naive bayes的精度可以达到98.06%,召回率为87.6%,准确率为97.6%。
在kaggle上有一个类似的dataset,叫SMS Spam Collection Dataset,kaggle spam sms。Brett书中的数据是对该数据源的修正,内容上基本类似,这里使用kaggle的数据集。
首先加载需要用到的包
require(data.table) # read the csv file in a faster way
require(magrittr) # enable pipeline operator
require(tm) # construct text vector
require(e1071) # naive bayes
require(gmodels) # create contigency table
require(caret) # split data into training and test sets
首先通过data.table包来读取数据
sms_raw <- fread(
"../data/SMSSpamCollection",
header = FALSE,
encoding = "Latin-1",
sep = "\t"
)
需要注意的是,这里读取的是kaggle上的数据集,同时也在UCI上存放,uci sms。SMSSpamCollection有两列数据,第一列表示短信是否为垃圾短信,spam表示垃圾短信,ham表示正常短信;第二列是具体的短信内容。将SMSSpamCollection的内容读入R语言,并保存在sms_raw变量中,通过setnames修改列名。
setnames(sms_raw, c("type", "text"))
sms_raw[, type:= factor(type)]
sms_raw有5574行文本数据。
sms_raw[, .N]
[1] 5574
可以大致看看sms_raw中垃圾短信与正常短信的条数
table(sms_raw[, type]) %>% prop.table()
ham spam
0.8659849 0.1340151
其中有86.6%左右的短信是正常短信,有13.4%的短信为垃圾短信。下面分别看一下垃圾短信与正常短信的云图,看看二者在文本内容上是否有显著区别。
require(dplyr)
require(wordcloud)
require(RColorBrewer)
pal <- brewer.pal(7, "Dark2")
sms_raw[type == "spam", text] %>%
wordcloud(min.freq = 20,
random.order = FALSE, colors = pal
)
sms_raw[type == "ham", text] %>%
wordcloud(min.freq = 70,
random.order = FALSE, colors = pal
)
通过如上的程序,分别得到spam有ham短信的云图,可以从云图上可以得知,垃圾短信中以较大的概率出现free,而正常短信更多的是一些常规描述的词汇。根据出现的单词,可以大致判断出短信是否是垃圾短信。
ham_wordcloud.jpg spam_wordcloud.jpg本文采用naive bayes的方法识别垃圾短信,根据naive bayes的条件独立的假设,若确定短信的是垃圾短信与否之后,短信内的单词相互独立,相互不影响。虽然这个假设与事实不符,比如若垃圾短信中出现buy,那么出现on sale的概率会相对而言更大(相对于不知道buy这个单词,此处是条件独立的转述),但是这对结果却没多大影响。
为评估naive bayes的性能,将数据集分成训练集和测试集,75%的数据用于训练,25%的数据用于评估算法的性能。
set.seed(1071)
train_index <- createDataPartition(sms_raw$type, p = 0.75, list = FALSE)
sms_raw_train <- sms_raw[train_index, ]
sms_raw_test <- sms_raw[-train_index, ]
其中设置随机数是为了结果的可重复性,1071是我常用的一个随机数,是R语言中著名的机器学习包e1071的数字部分。createDataPartition则来自于caret包,该函数通过对数据进行抽样,保证训练集与测试集中,垃圾短信的比例一致,避免训练集中出现大量的正常短信,而几乎没有垃圾短信这样的情况。createDataPartition 的第一个参数是vector,函数根据这个参数内容进行抽样,p=0.75表示75%的数据进入训练集,则有25%的数据进入测试集,list=FASLE表示返回结果的格式为常规的数组,否则将返回一个列表,更具体的用法可以参考对应的帮助文档。
下面看看训练集与测试集中spam/ham邮件的分布情况,通过上述的抽样,训练集与测试集中邮件分布一致。
table(sms_raw_train[, type]) %>% prop.table()
ham spam
0.8658537 0.1341463
table(sms_raw_test[, type]) %>% prop.table()
ham spam
0.8663793 0.1336207
为简化后续的文本处理,定义两个辅助函数,corpus生成语料库,clean函数则对语料库进行一些清洗,比如删除数字,stopwords,标点符号,首尾的空白字符等。
corpus <- function(x) VectorSource(x) %>% VCorpus(readerControl = list(reader = readPlain))
clean <- function(x) {
x %>%
tm_map(content_transformer(tolower)) %>%
tm_map(content_transformer(removeNumbers)) %>%
tm_map(content_transformer(removeWords), stopwords()) %>%
tm_map(content_transformer(removePunctuation)) %>%
tm_map(content_transformer(stripWhitespace))
}
这里有一个坑,如果使用常规的Corpus函数代替上述的VCorpus,则有可能导致后续的预测出现反常的现象,比如大部分的spam邮件预测为ham。具体原因我没有细究,在调试跟踪tmt包多个函数后,发现若使用Corpus函数,在对训练集的文本构建document-term frequency矩阵时,出现错误的结果。所以解决方法是使用VCorpus,并指定reader,在后续的操作中,使用content_transformer包装所有的处理函数。
通过上述的辅助函数,构建训练集的预料数据
sms_corpus_train <- corpus(sms_raw_train[, text]) %>% clean
可以通过inspect函数查看对应的corpus
sms_raw_train[1, text]
[1] "Ok lar... Joking wif u oni..."
inspect(sms_corpus_train[[1]])
<<PlainTextDocument>>
Metadata: 7
Content: chars: 23
ok lar joking wif u oni
可知原始的文本,已经转换成小写单词,且删除了标点符号。
下一步根据corpus构建文档词频矩阵(document-term frequency)
sms_dtm_train_all <- DocumentTermMatrix(sms_corpus_train)
可以删除出现次数过少的单词,这些单词出现较少,删除这些单词对预测的结果没有(估计)影响。
sms_dict <- findFreqTerms(sms_dtm_train_all, 5)
删除词频少于5的单词,剩下的单词做为后续构建dtm的单词表sms_dict,sms_dict其实是字符串列表,是那些出现频次超过5次的单词。
根据前文构建的单词表,重新构建训练集和测试集的dtm矩阵。
sms_dtm_train <- DocumentTermMatrix(
sms_corpus_train, control = list(dictionary = sms_dict)
)
sms_dtm_test <- DocumentTermMatrix(
sms_corpus_test, control = list(dictionary = sms_dict)
)
DocumentTermMatrix返回的是一种特殊的矩阵,类似于稀疏矩阵,是继承于slam包的simple_triplet_matrix,不过不需要深入了解底层的结构,把sms_dtm_train当成普通的矩阵看待即可。
sms_dtm_train_all %>% class
[1] "DocumentTermMatrix" "simple_triplet_matrix"
在naive bayes的算法中,计算的是单词表中,每个单词出现与否的概率,上述产生的dtm矩阵记录的是每条短信中,每个单词的出现次数,因此需要做进一步的转换。若出现次数大于0次,表明该单词出现在短信中,需要做一定的操作。
convert_counts <- function(x) {
x <- ifelse(x > 0, "Yes", "No")
}
sms_train <- sms_dtm_train %>%
apply(MARGIN = 2, convert_counts)
sms_test <- sms_dtm_test %>%
apply(MARGIN = 2, convert_counts)
通过上述的方法,得到一个跟dtm矩阵同样大小的矩阵(因为使用了apply,sms_train是普通的矩阵),且各元素是字符“Yes”或者“No”。
下面调用e1071中的naive bayes函数,对训练数据进行模型训练,得到模型nb_model_0,并在测试集上使用该模型预测。(预测可能需要数秒钟的时间,后续将会介绍更快的glmnet方法对垃圾短信的识别,到时训练和预测都能比较快速的完成)
nb_model_0 <- naiveBayes(sms_train, sms_raw_train$type)
pred_0 <- predict(nb_model_0, sms_test)
为了看预测结果的精度,构建confusion matrix
CrossTable(
sms_raw_test$type, pred_0,
prop.t = FALSE, prop.chisq = FALSE,
dnn = c("actual", "pred")
)
Cell Contents
|-------------------------|
| N |
| N / Row Total |
| N / Col Total |
|-------------------------|
Total Observations in Table: 1392
| pred
actual | ham | spam | Row Total |
-------------|-----------|-----------|-----------|
ham | 1202 | 4 | 1206 |
| 0.997 | 0.003 | 0.866 |
| 0.981 | 0.024 | |
-------------|-----------|-----------|-----------|
spam | 23 | 163 | 186 |
| 0.124 | 0.876 | 0.134 |
| 0.019 | 0.976 | |
-------------|-----------|-----------|-----------|
Column Total | 1225 | 167 | 1392 |
| 0.880 | 0.120 | |
-------------|-----------|-----------|-----------|
使用naive bayes预测垃圾短信,有(1202+163)/1392 = 98.06%的短信被正确分类,spam邮件的召回率为87.6%,准确率为97.6%,这就是本文最开始提到的结果。没有使用特别复杂的特征提取方法,仅仅通过naive bayes,就达到98%的准确分类,效果良好。