生信数学

基于EM算法的聚类R包mclust

2021-11-18  本文已影响0人  小潤澤

简介

mclust是一种基于高斯混合分布,利用EM算法聚类的方法:mclust文献转送

软件提供的高斯模型一共有14个:



这14个模型投影到二维空间呈现出如下的形式:



每个圆圈相当于高斯密度分布(pdf)的定义域(二维),这里作者给出了聚成3类的情况

而每个圆圈(定义域)还原成三维的高斯密度分布曲线如下图:



其中x,y轴代表两个变量(waiting和eruptions);而z轴代表密度,很显然在这幅图里面分为了两个高斯密度分布,也可以理解为分为了两个类

EM算法与高斯混合分布

EM算法比较经典的一种运用就是求解高斯混合分布,举个不那么恰当的例子,男女生身高的问题,一般情况下男生身高要高于女生,那么假设我有一组男生女生混合的身高数据(不请楚哪一个身高值是女生的,哪一个身高值是男生的),我们需要依据身高的数据反推哪一个身高值是女生的,哪一个身高值是男生的,所以隐含层是一个二分类问题,即判断哪一个身高值是女生的,哪一个身高值是男生的

如果数据好,那么会拟合出两个高斯密度分布曲线,因此只需要利用观测值(观测值是一组身高值),基于EM算法估计出这两个分布的密度曲线的参数即可。当新引入一个观测值时,我们就可以计算该观测值在两个分布中的似然值,根据两个分布中似然值的大小就可以判断出该观测值到底属于哪一个分布曲线了(哪一类)


如上图所示,① 在左边的密度分布的似然值要大于右边密度分布的似然值,因此更容易将 ① 划分到女生群体里面;② 在左边的密度分布的似然值要小于右边密度分布的似然值,因此更容易将 ② 划分到男生群体里面

代码分析

# 数据准备
install.packages("mclust")

install.packages("gclus")
data("wine", package = "gclus")
dim(wine)

# 获取一组观测值
data = wine[,-1]
G = NULL
# 初始化那14种模型
modelNames = NULL 
prior = NULL
control = emControl()
initialization = list(hcPairs = NULL, 
                      subset = NULL, 
                      noise = NULL)
Vinv = NULL
warn = mclust.options("warn")
x = NULL
verbose = interactive()

# 整合data,获取基本属性
dimData <- dim(data)
oneD <- (is.null(dimData) || length(dimData[dimData > 1]) == 1)
data <- as.matrix(data)
n <- nrow(data)
d <- ncol(data)
modelNames <- mclust.options("emModelNames")  
  
# 软件默认的分群是分出1-9类
G <- 1:9 
Gall <- G
Mall <- modelNames
l <- length(Gall)
m <- length(Mall)

# 这个是R里面进度条的写法
pbar <- txtProgressBar(min = 0, max = l*m+1, style = 3)
on.exit(close(pbar))
ipbar <- 0

# 初始化14个模型的BIC矩阵
EMPTY <- -.Machine$double.xmax
BIC <- RET <- matrix(EMPTY, nrow = l, ncol = m, 
                     dimnames = list(as.character(Gall), as.character(Mall)))

# 设置循环
G <- as.numeric(G)
Glabels <- as.character(G)
Gout <- G

# 对划分为一个高斯密度分布的参数进行估计,并用BIC准则筛选最佳模型
for(mdl in modelNames[BIC["1",] == EMPTY]) 
{
  # 每一个 mdl 代表14个模型中的一个
  ## mvn函数利用EM算法对一个高斯密度分布的参数进行估计
  out <- mvn(modelName = mdl, data = data, prior = prior)
  ## 计算模型的BIC准则
  ## out$loglik代表经过EM算法估计的最佳参数(期望)对应的似然函数的最大似然值(对数化)
  BIC["1", mdl] <- bic(modelName = mdl, loglik = out$loglik, 
                       n = n, d = d, G = 1, equalPro = FALSE)
  RET["1", mdl] <- attr(out, "returnCode")

  # 打印进度条
  if(verbose) 
  { ipbar <- ipbar+1; setTxtProgressBar(pbar, ipbar) }
}

# 对分2-9个类
G <- G[-1]
Glabels <- Glabels[-1]

# 先用层次聚类将数据分为2 - 9个类
hcPairs <- hc(data = data, modelName = mclust.options("hcModelName"),use = mclust.options("hcUse"))
clss <- hclass(hcPairs, G)

# 每个g代表2 - 9个类的其中一种情况;g=2—9
for (g in Glabels) 
{
  cl <- clss[,g]
  ipbar <- ipbar+1
  setTxtProgressBar(pbar, ipbar) 
  ## 划分为0,1矩阵表示某观测值是否属于该类
  z <- unmap(cl, groups = 1:max(cl))
    
  # 对14个模型进行遍历
  for(modelName in na.omit(modelNames[BIC[g,] == EMPTY])) 
  {  
    # 每一个 mdl 代表14个模型中的一个
    ## me函数利用EM算法对多个高斯密度分布的参数进行估计
    out <- me(data = data, modelName = modelName, z = z,
                prior = prior, control = control, warn = warn)
    ## 计算模型的BIC准则
    ## out$loglik代表经过EM算法估计的最佳参数(期望)对应的似然函数的最大似然值(对数化)
    BIC[g, modelName] <- bic(modelName = modelName, 
                             loglik = out$loglik,
                             n = n, d = d, G = as.numeric(g), 
                             equalPro = control$equalPro)
    RET[g, modelName] <- attr(out, "returnCode")
      
    # 打印进度条
    ipbar <- ipbar+1
    setTxtProgressBar(pbar, ipbar)  
  }
}

# 赋值
structure(BIC, G = Gout, modelNames = modelNames, 
          prior = prior, Vinv = Vinv, control = control, 
          initialization = list(hcPairs = hcPairs, 
                                subset = initialization$subset,
                                noise = initialization$noise), 
          warn = warn, n = n, d = d, oneD = oneD,
          criterion = "BIC", returnCodes = RET, 
          class = "mclustBIC")

  
BIC

每一行代表分为1-9个高斯混合模型,每一列代表14个模型,里面的元素值代表BIC准则

当层次聚为3类的时候:

z矩阵
每一行代表其中一个观测值,每一列代表分成的三个类,0和1分别代表不属于/属于该类

其中要注意的是:
EM算法的期望最大化是利用极大似然法计算的,那么代码中的out$loglik经过EM算法估计的最佳参数(期望)对应的似然函数的最大似然值(对数化)

整体代码流程

data("wine", package = "gclus")
dim(wine)
X <- data.matrix(wine[,-1])
mod <- Mclust(X)

table(wine$Class, mod$classification)
# 如下是输出信息
    1  2  3
 1 59  0  0
 2  0 69  2
 3  0  0 48
# adjustedRandIndex:评估聚类效果

结果显示聚成三类更好,仅有2例没有正确聚类

mvn函数:https://cran.r-project.org/web/packages/MVN/vignettes/MVN.html
参考:https://www.jianshu.com/p/538a3cc66697

上一篇下一篇

猜你喜欢

热点阅读