74-R使用朴素贝叶斯分类器识别垃圾邮件
《机器学习-实用案例解析》学习笔记
1、数据准备
数据下载:https://spamassassin.apache.org/old/publiccorpus/
> library(pacman)
> p_load(chinese.misc,stringr,dplyr)
列出文件夹下所有的文件。每个文件夹中都有一个名为cmds的文件,是一个很长的linux基本命令列表,读取时需要忽略这个文件。
> easy.ham.files <- dir_or_file("./easy_ham")
> easy.ham2.files <- dir_or_file("./easy_ham_2")
> hard.ham.files <- dir_or_file("./hard_ham")
> hard.ham2.files <- dir_or_file("./hard_ham_2")
> spam.files <- dir_or_file("./spam")
> spam2.files <- dir_or_file("./spam_2")
spam为垃圾邮件,easy为容易识别的正常邮件,hard为不易识别的正常邮件。
构造文件读取函数。RFC822协议规定邮件的头部和正文必须使用空行分隔,所以我们只需要读取第一个空行之后的内容。
使用readr包read_file函数读取整个文件,\n为换行,\n\n为一个空行。
> # 数据读取函数
> read_fun <- function(f) {
+ if (!str_detect(f,"cmds")) {
+ f.txt <- readr::read_file(f)
+ # 按第一个空行切割
+ txt <- str_split_fixed(f.txt,"\n\n",2)
+ txt <- txt[1,2]
+ id <- str_extract(f,"\\d+(?=\\.)")
+ df <- tibble(id=id,content=txt)
+ df <- filter(df,content!="")
+ return(df)
+ }
+ }
读取邮件内容,并将同类型邮件合并为一个文件,并做基本的数据整理。
> # 数据基本整理函数
> pre_fun <- function(string) {
+ # 清除空格,必须最先清除
+ string <- str_replace_all(string,"\\s+"," ")
+ # 清除html标识
+ # string <- str_replace_all(string,"<.*?>"," ")
+ # 全部转为小写
+ string <- tolower(string)
+ # 清除非字母字符
+ string <- str_replace_all(string,"[^a-z]"," ")
+ string <- str_replace_all(string,"\\s+"," ")
+ string <- str_trim(string,side = "both")
+ return(string)
+ }
>
> spam1 <- sapply(spam.files,read_fun) %>% do.call(bind_rows,.)
> spam2 <- sapply(spam2.files,read_fun) %>% do.call(bind_rows,.)
> # 合并并清除重复行
> spam <- bind_rows(spam1,spam2) %>% distinct(id,content)
> spam$content <- spam$content %>% pre_fun
最终没有清除HTML标记,是因为它们可能为垃圾邮件的典型特性。
> p_load(text2vec)
>
> it.spam <- itoken(spam$content,
+ ids = spam$id,
+ progressbar = F)
> # 创建训练集词汇表
> vocab.spam <- create_vocabulary(it.spam)
> # 选择至少两个文档都包含的词
> vocab.spam <- prune_vocabulary(vocab.spam,doc_count_min = 2)
> # 去除停用词
> stopword <- readr::read_table("D:/R/dict/english_stopword.txt",
+ col_names = F)
> vocab.spam <- anti_join(vocab.spam,stopword,by=c("term"="X1"))
2、构建垃圾邮件数据
构建特征词项在垃圾邮件中的条件概率,训练分类器,使之能在已知观测特征的前提下计算出邮件是垃圾的概率。
> spam.df <- as_tibble(vocab.spam) %>% rename(frequency=term_count) %>%
+ mutate(density_term=frequency/sum(frequency),
+ occurrence_doc=doc_count/nrow(spam)) %>%
+ select(-doc_count)
>
> arrange(spam.df,-occurrence_doc) %>% head()
## # A tibble: 6 x 4
## term frequency density_term occurrence_doc
## <chr> <int> <dbl> <dbl>
## 1 http 14254 0.0137 0.830
## 2 html 4932 0.00474 0.593
## 3 email 4047 0.00389 0.548
## 4 click 2630 0.00253 0.527
## 5 href 5993 0.00577 0.509
## 6 body 2745 0.00264 0.485
> arrange(spam.df,-frequency) %>% head()
## # A tibble: 6 x 4
## term frequency density_term occurrence_doc
## <chr> <int> <dbl> <dbl>
## 1 font 52565 0.0506 0.476
## 2 td 26787 0.0258 0.352
## 3 br 23361 0.0225 0.469
## 4 size 19393 0.0187 0.485
## 5 tr 16001 0.0154 0.360
## 6 width 14318 0.0138 0.370
HTML标签好像是垃圾邮件中最明显的文本特征,但是如果采用frequency和density_term作为训练数据,就会把包含HTML标签的垃圾邮件权重调得过高,然而并不是所有的垃圾邮件都是这种方式产生的,所以较好的方法是:根据有多少邮件包含这个特征词项(occurrebce_doc)来定义一封邮件是垃圾邮件的概率。
3、构建正常邮件数据
正常邮件包括容易识别的和不易识别的正常邮件,按同样的方法构建数据集。
> normal1 <- sapply(easy.ham.files,read_fun) %>% do.call(bind_rows,.)
> normal2 <- sapply(easy.ham2.files,read_fun) %>% do.call(bind_rows,.)
> normal3 <- lapply(hard.ham.files,read_fun) %>% do.call(bind_rows,.)
> normal4 <- sapply(hard.ham2.files,read_fun) %>% do.call(bind_rows,.)
> # 合并并清除重复行
> normal <- bind_rows(normal1,normal2,normal3,normal4) %>%
+ distinct(id,content)
>
> # 取出与spam相同的行数
> normal <- normal[1:nrow(spam),]
>
> normal$content <- normal$content %>% pre_fun
> it.normal <- itoken(normal$content,
+ ids = normal$id,
+ progressbar = F)
> # 创建训练集词汇表
> vocab.normal <- create_vocabulary(it.normal)
> # 选择至少两个文档都包含的词
> vocab.normal <- prune_vocabulary(vocab.normal,doc_count_min = 2)
> vocab.normal <- anti_join(vocab.normal,stopword,by=c("term"="X1"))
> normal.df <- as_tibble(vocab.normal) %>%
+ rename(frequency=term_count) %>%
+ mutate(density_term=frequency/sum(frequency),
+ occurrence_doc=doc_count/nrow(spam)) %>%
+ select(-doc_count)
>
> arrange(normal.df,-occurrence_doc) %>% head()
## # A tibble: 6 x 4
## term frequency density_term occurrence_doc
## <chr> <int> <dbl> <dbl>
## 1 http 3901 0.0156 0.724
## 2 list 2366 0.00944 0.432
## 3 listinfo 1078 0.00430 0.404
## 4 net 2384 0.00951 0.355
## 5 wrote 979 0.00391 0.338
## 6 mailing 877 0.00350 0.326
两个数据集都构建完毕,但是我们假定每一封邮件是垃圾邮件或正常邮件的概率是相等的,所以,后面我们会从数量更多的正常邮件中选择与垃圾邮件相同数量的邮件作为训练数据。
4、构建分类器
设置先验概率为0.5,对于未出现在训练数据中的词项,直接将概率赋值为0是不合适的,因为第一,没有出现不代表它永远不会在邮件中出现,第二,条件概率是通过乘积来计算的,如果赋值为0,一遇到未知的词项就会把所有的概率值都计算为0。
所以最好的方法是将未知词项的概率设置为很小,比如0.0001%。
> p_load(tidytext)
>
> prior=0.5
> c=1e-6
>
> classify_email <- function(path,train.df) {
+ msg <- read_fun(path)
+ msg$content <- pre_fun(msg$content)
+ # 计算词频,并去除停用词
+ msg.tf <- msg %>% unnest_tokens(term,content) %>% count(id,term) %>%
+ anti_join(stopword,by=c("term"="X1"))
+ # 求交集,找出匹配到的词项
+ msg.match <- intersect(msg.tf$term,train.df$term)
+ if (length(msg.match) < 1) {
+ # 如果全部没有匹配到
+ return(prior*c^(length(msg.tf$n)))
+ } else {
+ # 至少匹配到一个
+ match.prob <- train.df$occurrence_doc[match(msg.match,train.df$term)]
+ return(prior*prod(match.prob)*c^(length(msg.tf$term)-length(msg.match)))
+ }
+ }
msg.match将保存这封邮件中的所有在训练集train.df中出现过的特征词项。如果交集为空,那么msg.match的长度就比1小,于是就用先验概率prior乘以小概率值c的邮件特征数次幂。得到的结果就是这封邮件被分类为垃圾邮件的概率,值很小。
相反,如果这个交集不为空,我们需要找出这些同时出现在训练集和新邮件中的特征词项,然后查出它们在文档中出现的概率(occurrence)。使用match函数完成查找,它能找到词项在训练数据的term列中出现的位置。我们根据这些位置可以从occurrence列中返回特征所对应的文档概率,返回的值保存在match.probs中。然后,计算这些返回值的乘积(prod),并将乘积结果再与下列值相乘:邮件为垃圾邮件的先验概率、特征词项的出现概率以及缺失词项(未出现在训练集中的词项)的小概率。获得的结果就是在已知邮件中有哪些词项出现在训练集中后,对于它是垃圾邮件的贝叶斯概率估计值。
随机抽取较难识别的邮件中的一封做个简单的测试:
> path <- hard.ham2.files[sample(length(hard.ham2.files),1)]
>
> res.spam <- classify_email(path,spam.df)
> res.nrom <- classify_email(path,normal.df)
> res <- ifelse(res.spam > res.nrom,TRUE,FALSE)
> print(res)
## [1] FALSE
返回FALSE,说明识别对了。
5、用更多邮件测试分类器
> classifier <- function(emails) {
+ res.spam <- sapply(emails,
+ function(p) classify_email(p,spam.df))
+ res.nrom <- sapply(emails,
+ function(p) classify_email(p,normal.df))
+ res <- ifelse(res.spam > res.nrom,1,0)
+ results <- data.frame(res.spam=res.spam,
+ res.nrom=res.nrom,
+ result =res)
+ return(results)
+ }
> # 容易识别的正常邮件
> list1 <- easy.ham.files[!str_detect(easy.ham.files,"cmds")]
> result1 <- classifier(list1)
> tab1 <- table(result1$result)
> # 不容易识别的正常邮件
> list2 <- hard.ham.files[!str_detect(hard.ham.files,"cmds")]
> result2 <- classifier(list2)
> tab2 <- table(result2$result)
> # 垃圾邮件
> list3 <- spam.files[!str_detect(spam.files,"cmds")]
> result3 <- classifier(list3)
> tab3 <- table(result3$result)
在正常邮件中,result=0表示识别正确,在垃圾邮件中,result=1表示识别正确。那么,容易识别的正常邮件的结果为:
> prop.table(tab1)
##
## 0 1
## 0.98812116 0.01187884
不易识别的正常邮件的结果为:
> prop.table(tab2)
##
## 0 1
## 0.904 0.096
垃圾邮件的识别结果为:
> prop.table(tab3)
##
## 0 1
## 0.1698302 0.8301698
可以看到,分类器的效果还是相当不错的,容易识别的正确率超过98.8%,不易识别的正确率为90.4%,垃圾邮件的识别正确率约为83%。
用散点图把结果绘制出来:
> p_load(ggplot2)
>
> ggplot() +
+ geom_jitter(data = result1,aes(log(res.nrom),log(res.spam)),
+ col="green",size=1) +
+ geom_jitter(data = result2,aes(log(res.nrom),log(res.spam)),
+ col="blue",size=1) +
+ geom_jitter(data = result3,aes(log(res.nrom),log(res.spam)),
+ col="red",size=1) +
+ geom_abline(intercept = 0,size=1) +
+ theme_bw()
邮件识别情况
从图中也可以直观的看到分类器的效果非常不错。
用对数(log)转换的原因是因为许多预测概率非常小,而另一些又不是那么小,差别悬殊,不太容易直接比较结果,转换后更容易比较。
对角直线y=x是一个简单的决策边界,直线之上代表垃圾邮件,之下代表正常邮件。
y轴上有一些蓝色的点(不易识别的邮件),说明这些点被判定为垃圾邮件的概率大于0,而判定为正常邮件的概率接近0,所以落在了y轴上。
不易识别的邮件(蓝色点)分布在直线两边,说明可能训练集中不易识别的邮件数据不足,导致还有很多与正常邮件相关的特征没有被纳入训练集。
改进:我们默认正常邮件和垃圾邮件的先验概率都为0.5,但现实情况可能并非如此,可能为0.8:0.2。另外,可以通过增加正常邮件的数量来训练模型,所以模型应该还有很大的优化空间。