利用 Segment Anything实现医学图像分割

2023-04-11  本文已影响0人  此间不留白

前言

Meta 前不久发布了图像分割通用模型 Segment Anything,标志着图像分割领域的 ChatGPT 时刻来临。通用模型在图像处理领域表现出了强大的潜力,业内有人戏称:“Segment Anything 的出现,代表着曾经作为图像处理领域主流的图像分割任务基本不存在了,分割已经没什么任务需要继续去做了^_^”。Segment Anything 在自然图像分割领域性能强大,但就医学图像分割(如靶区勾画等)而言,其局限性仍然显著存在。不过,Segment Anything 作为一个强大的工具,可以成为医学图像分割任务前处理或后处理的利器(如对分割的目标区域进行约束等)。以下是一个利用 Segment Anything 实现医学图像分割的简单 demo,以供参考;基于此 demo,可以发掘 Segment Anything 更多有趣的玩法。

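基本流程

整体流程为:先按窗位/窗宽将体数据灰度线性映射到 [0, 255];再将每张 2D 切片复制为三通道图像;最后用 SamPredictor(点提示)或 SamAutomaticMaskGenerator(全图自动分割)逐层分割。关键代码片段如下: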

def transformNsCV(self):
    """Linearly window ``self.imgArray`` onto the display range [0, 255].

    ``self.windowsCenter`` is the window centre and ``self.windowsLevel`` its
    width: intensities from about ``center - width/2`` up to
    ``center + width/2`` are mapped onto 0..255. The scaled values are
    truncated toward zero and clipped to [0, 255].

    Returns:
        A float array with the same shape as ``self.imgArray`` whose values
        are integral and lie in [0, 255].

    Raises:
        ValueError: if the window width is zero (the scale factor would be
            a division by zero).
    """
    minWindow = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
    maxWindow = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
    if maxWindow == minWindow:
        raise ValueError("window width (windowsLevel) must be non-zero")
    dFactor = 255.0 / (maxWindow - minWindow)

    transArray = np.trunc((self.imgArray - minWindow) * dFactor)
    # Saturate everything outside the window to the ends of the display range.
    return np.clip(transArray, 0, 255)

# Point-prompt segmentation of a single RGB slice. This is a fragment:
# `self.model` and `targetSlice` come from the surrounding class method.
predictor = SamPredictor(self.model)
predictor.set_image(targetSlice)

# Point prompts in (x, y) pixel coordinates.
seedsArr = np.array([[100, 245], [500, 345]])
# SAM point labels: 1 = foreground, 0 = background. Every seed is a
# foreground point (labels 2/3 are reserved by SAM for box corners).
labels = np.ones(len(seedsArr), dtype=int)
masks, scores, logits = predictor.predict(
    point_coords=seedsArr,
    point_labels=labels,
    multimask_output=True,
)


def pipelineSegAllImage(self, targetSlice: np.ndarray):
    """Run SAM's automatic mask generator on one slice and merge the masks.

    Args:
        targetSlice: Slice indexed as ``[row, col, channel]``; it is cast to
            uint8 before being handed to the generator.

    Returns:
        A 2D array in which each pixel holds the number of generated masks
        covering it (0 where no mask was produced).
    """
    maskSliceRes = np.zeros_like(targetSlice[:, :, 0])

    # Was `self.mdoel` (typo), which raised AttributeError on every call.
    mask_generator = SamAutomaticMaskGenerator(self.model)
    masks = mask_generator.generate(targetSlice.astype(np.uint8))
    for mask in masks:
        maskSliceRes = maskSliceRes + mask["segmentation"].astype(np.uint8)

    return maskSliceRes

其中一张 slice 的分割效果如下。由于分割是在 2D 切片上逐层进行的,最终结果在 z 方向上的连续性可能不够好。


综上,利用 Segment Anything 实现医学图像分割的全流程代码如下:


from segment_anything import sam_model_registry
from segment_anything import SamPredictor,SamAutomaticMaskGenerator
import SimpleITK as sitk
import numpy as np
import torch as t
import cv2
import matplotlib.pyplot as plt
from itertools import product
import random
import sys
from ReadAndWrite import ReadImageBase

class SamMedSegmentation:
    """Segment a 3D medical image slice by slice with Segment Anything (SAM).

    The volume is windowed onto [0, 255] (:meth:`transformNsCV`), every slice
    along axis 0 is turned into a 3-channel image (:meth:`convertMedImg2CV`),
    and SAM is run on each slice, either from point prompts
    (:meth:`pipelineSegPoint`) or with the automatic mask generator
    (:meth:`pipelineSegAllImage`).

    Args:
        imgArray: 3D image array; its shape is unpacked as ``(z, y, x)``.
        windowsCenter: Centre of the intensity window.
        windowsLevel: Width of the intensity window (the window spans roughly
            ``windowsCenter +/- windowsLevel / 2``).
        seedlists: Optional list of ``[x, y]`` point prompts used on every
            slice by :meth:`pipelineSegPoint`. When ``None``, one random seed
            is stored in ``self.seedlists`` and fresh random seeds are drawn
            for each slice.
    """

    def __init__(self, imgArray: np.ndarray, windowsCenter: int, windowsLevel: int, seedlists=None):
        self.imgArray = imgArray
        assert len(self.imgArray.shape) == 3
        # Remember whether the caller supplied prompts. Previously seedlists was
        # stored but never read, so user-provided seeds were silently ignored.
        self._userSeeds = seedlists is not None
        if seedlists is None:
            self.seedlists = self.generateSeeds(1)
        else:
            self.seedlists = seedlists

        self.windowsCenter = windowsCenter
        self.windowsLevel = windowsLevel

        self.model = self.initializeModel()

    def initializeModel(self, modelType="vit_l",
                        checkPoint=r"/mnt/e/ChromeDwnLoad/sam_vit_l_0b3195.pth",
                        device="cuda"):
        """Build the SAM model and move it to ``device``.

        Args:
            modelType: Key into ``sam_model_registry`` (default ``"vit_l"``).
            checkPoint: Path of the checkpoint file to load.
            device: Device passed to ``sam.to`` (default ``"cuda"``).

        Returns:
            The model built by ``sam_model_registry[modelType]``.
        """
        sam = sam_model_registry[modelType](checkpoint=checkPoint)
        sam.to(device)
        return sam

    def generateSeeds(self, maxIter):
        """Draw ``maxIter`` random point prompts from a fixed grid.

        Candidate x values are fractions (1/8, 1/7, 1/6, 1/4, 1/3, 1/2) of the
        image width; candidate y values are fractions (1/8, 1/7, 1/5, 1/4,
        1/3, 1/2) of the image height. Points are sampled with replacement
        using the module-level ``random`` generator.

        Returns:
            A list of ``maxIter`` ``[x, y]`` pairs, the (X, Y) order SAM's
            ``point_coords`` expects.
        """
        _, y, x = self.imgArray.shape
        px = (int(x/8), int(x/7), int(x/6), int(x/4), int(x/3), int(x/2))
        py = (int(y/8), int(y/7), int(y/5), int(y/4), int(y/3), int(y/2))

        # product(px, py) yields (x, y) pairs. The previous product(py, px)
        # produced (y, x) pairs that were then mislabelled as (x, y).
        candidates = [np.array(p) for p in product(px, py)]
        point_xy = []
        for _ in range(maxIter):
            chosen = random.choice(candidates)
            point_xy.append([chosen[0], chosen[1]])
        return point_xy

    def transformNsCV(self):
        """Linearly window ``self.imgArray`` onto the display range [0, 255].

        Intensities from about ``center - width/2`` up to ``center + width/2``
        are mapped onto 0..255, truncated toward zero and clipped.

        Returns:
            A float array shaped like ``self.imgArray`` with integral values
            in [0, 255].

        Raises:
            ValueError: if the window width is zero.
        """
        minWindow = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
        maxWindow = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
        if maxWindow == minWindow:
            raise ValueError("window width (windowsLevel) must be non-zero")
        dFactor = 255.0 / (maxWindow - minWindow)

        transArray = np.trunc((self.imgArray - minWindow) * dFactor)
        return np.clip(transArray, 0, 255)

    def convertMedImg2CV(self, targetArray: np.ndarray):
        """Split a 3D array into a list of 3-channel slices along axis 0.

        Each 2D slice is replicated into three identical channels, giving an
        array of shape ``(rows, cols, 3)`` with the slice's dtype. This
        matches what ``cv2.COLOR_GRAY2RGB`` produces. numpy is used instead
        because ``cv2.cvtColor`` rejects float64 input, and
        :meth:`transformNsCV` returns float64 for integer volumes.
        """
        assert len(targetArray.shape) == 3
        return [np.repeat(targetSlice[:, :, np.newaxis], 3, axis=2) for targetSlice in targetArray]

    def pipelineSegAllImage(self, targetSlice: np.ndarray):
        """Run SAM's automatic mask generator on one slice and merge the masks.

        Returns:
            A 2D array where each pixel holds the number of generated masks
            covering it.
        """
        maskSliceRes = np.zeros_like(targetSlice[:, :, 0])

        # Was `self.mdoel` (typo), which raised AttributeError on every call.
        mask_generator = SamAutomaticMaskGenerator(self.model)
        masks = mask_generator.generate(targetSlice.astype(np.uint8))
        for mask in masks:
            maskSliceRes = maskSliceRes + mask["segmentation"].astype(np.uint8)

        return maskSliceRes

    def pipelineSegPoint(self, targetSlice: np.ndarray):
        """Segment one slice from point prompts and merge SAM's output masks.

        Uses the caller-supplied ``seedlists`` if one was given to the
        constructor, otherwise 4 fresh random seeds from :meth:`generateSeeds`.

        Returns:
            A uint8 2D array where each pixel holds the number of returned
            masks covering it.
        """
        targetSlice = targetSlice.astype(np.uint8)
        maskSliceRes = np.zeros_like(targetSlice[:, :, 0])

        predictor = SamPredictor(self.model)
        predictor.set_image(targetSlice)

        # getattr keeps instances created without __init__ working (random seeds).
        if getattr(self, "_userSeeds", False):
            seedsArr = np.array(self.seedlists)
        else:
            seedsArr = np.array(self.generateSeeds(4))
        # SAM point labels: 1 = foreground, 0 = background. The former
        # 1, 2, 3, ... labels gave every point after the first a label SAM
        # does not treat as a foreground click.
        labels = np.ones(len(seedsArr), dtype=int)
        masks, scores, logits = predictor.predict(
            point_coords=seedsArr,
            point_labels=labels,
            multimask_output=True,
        )

        for mask in masks:
            maskSliceRes += mask.astype(np.uint8)

        return maskSliceRes

    def processPipeline(self, all=False):
        """Window the volume and segment every slice.

        Args:
            all: If truthy, use :meth:`pipelineSegAllImage` (automatic masks);
                otherwise use :meth:`pipelineSegPoint` (point prompts).

        Returns:
            An array shaped and typed like ``self.imgArray`` whose slice ``i``
            holds the merged mask counts for slice ``i``.
        """
        targetArray = self.transformNsCV()
        imgslists = self.convertMedImg2CV(targetArray)
        segmentSlice = self.pipelineSegAllImage if all else self.pipelineSegPoint

        maskResults = np.zeros_like(self.imgArray)
        for idx, imgSlice in enumerate(imgslists):
            maskResults[idx, :, :] = segmentSlice(imgSlice)

        return maskResults
    


if __name__ == '__main__':
    # Load the first '.nrrd' volume found by the project-local reader.
    image_dir = r"./0522c0149"
    output_dir = r"./out"
    volume_reader = ReadImageBase(image_dir, output_dir, '.nrrd')
    volume, volume_meta = volume_reader[0]

    # Window centre 50, width 350; point-prompt mode (all=False).
    segmenter = SamMedSegmentation(volume, 50, 350)
    seg_volume = segmenter.processPipeline(all=False)
    print(seg_volume)
    print(np.max(seg_volume))

    volume_reader.writer(
        seg_volume.astype(np.uint8),
        volume_meta,
        "segPoint.nrrd",
        outPath=output_dir,
    )


上一篇 下一篇

猜你喜欢

热点阅读