单细胞单细胞数据分析单细胞测序

10X单细胞(10X空间转录组)批次效应去除大盘点2 & AWGAN

2021-11-23  本文已影响0人  单细胞空间交响乐

hello,大家好,今天我们要再度深入认识一下批次效应,关于批次效应,我们之前分享了很多了,文章列在这里,供大家参考

10X单细胞(10X空间转录组)多样本批次效应去除分析之RCA2

10X单细胞(10X空间转录组)整合分析批次处理之细节(harmony)

10X单细胞(10X空间转录组)批次去除(整合)分析之Scanorama

10X单细胞(10X空间转录组)批次效应去除大盘点

10X单细胞(空间转录组)数据整合分析批次矫正之liger

单细胞数据用Harmony算法进行批次矫正

批次效应

10X单细胞(10X空间转录组)数据分析总结之各种NMF

今天我们要深入一下批次去除的方法AWGAN,参考文献在AWGAN: A Powerful Batch Correction Model for scRNA-seq Data

目前去除批次效应的两种思路

AWGAN combines a Wasserstein Generative Adversarial Network (WGAN) with an attention mechanism to reduce the differences among batches.

图片.png

AWGAN的分析步骤,three key steps: attention-driven data preprocessing, AWGAN training, and model evaluation

各个软件之间的比较

Small-scale scRNA-seq Datasets

图片.png

Large-scale scRNA-seq Datasets

图片.png
图片.png

示例代码(python)

import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
import torch.utils.data as Data  #Data是用来批训练的模块
from torchvision.utils import save_image
import numpy as np
import os
import pandas as pd
import torch.optim.lr_scheduler as lr_s 
from collections import Counter
import loompy
from scipy.spatial.distance import cdist
import scprep
import imap  #used for feature detected
import numpy as np
import squidpy as sq
import pandas as pd
import matplotlib.pyplot as plt
import phate
import graphtools as gt
import magic
import os
import datetime
import scanpy as sc
from skmisc.loess import loess
import sklearn.preprocessing as preprocessing
import umap.umap_ as umap
from numba import jit
from sklearn.metrics import silhouette_score
import random

def silhouette_coeff_ASW(adata, method_use='raw',save_dir='', save_fn='', percent_extract=0.8):
    """Compute silhouette-based batch/cell-type scores on 20 random subsamples.

    For each of 20 iterations a fraction ``percent_extract`` of the cells is
    drawn without replacement, and the silhouette score of ``adata.X`` is
    computed once against ``adata.obs['batch']`` and once against
    ``adata.obs['louvain']``. Both scores are rescaled from [-1, 1] to [0, 1]
    and combined into an F-score that rewards low batch separation and high
    cell-type separation.

    Args:
        adata: object exposing ``obs_names``, ``X`` and ``obs`` with the
            columns ``'batch'`` and ``'louvain'``, sliceable by cell names.
        method_use: label written to the ``method_use`` column of the result.
        save_dir: prefix of the output path (concatenated, not path-joined).
        save_fn: file name without extension; ``'.csv'`` is appended.
        percent_extract: fraction of cells sampled in each iteration.

    Returns:
        pandas.DataFrame with one row per iteration and the columns
        ``asw_batch_norm``, ``asw_batch_norm_sub``, ``asw_celltype_norm``,
        ``fscore`` and ``method_use``. The frame is also written to
        ``save_dir + save_fn + '.csv'``.
    """
    # The original only seeded Python's `random`, which does not drive NumPy's
    # sampler, so the subsamples changed from run to run. Keep that call for
    # compatibility and draw from a dedicated, seeded NumPy generator instead.
    random.seed(0)
    rng = np.random.RandomState(0)

    n_extract = int(len(adata.obs_names) * percent_extract)
    # Silhouette scores live in [-1, 1]; these bounds rescale them to [0, 1].
    min_val = -1
    max_val = 1

    asw_fscore = []
    asw_bn = []
    asw_bn_sub = []
    asw_ctn = []
    for _ in range(20):
        rand_cidx = rng.choice(adata.obs_names, size=n_extract, replace=False)
        print('nb extracted cells: ',len(rand_cidx))
        adata_ext = adata[rand_cidx,:]
        asw_batch = silhouette_score(adata_ext.X, adata_ext.obs['batch'])
        asw_celltype = silhouette_score(adata_ext.X, adata_ext.obs['louvain'])
        asw_batch_norm = (asw_batch - min_val) / (max_val - min_val)
        asw_celltype_norm = (asw_celltype - min_val) / (max_val - min_val)

        # Harmonic-mean style score of (1 - batch separation) and cell-type separation.
        batch_mixing = 1 - asw_batch_norm
        fscoreASW = (2 * batch_mixing * asw_celltype_norm) / (batch_mixing + asw_celltype_norm)
        asw_fscore.append(fscoreASW)
        asw_bn.append(asw_batch_norm)
        asw_bn_sub.append(batch_mixing)
        asw_ctn.append(asw_celltype_norm)

    df = pd.DataFrame({'asw_batch_norm':asw_bn, 'asw_batch_norm_sub': asw_bn_sub,
                       'asw_celltype_norm': asw_ctn, 'fscore':asw_fscore,
                       'method_use':np.repeat(method_use, len(asw_fscore))})
    df.to_csv(save_dir + save_fn + '.csv')
    print('Save output of pca in: ',save_dir)
    print(df.values.shape)
    print(df.keys())
    return df

Real AWGAN

# Load the concatenated CRC dataset from a loom file.
adata = sc.read_loom('CRC_CONCAT.loom', sparse=False)
# Preprocessing, same as the preprocessing code in the model.
adata = imap.stage1.data_preprocess(adata)
# Ligand-receptor analysis on the data before batch correction, grouping
# cells by obs["celltype"] and using 2000 permutations. With copy=True the
# result is returned; the output pasted below this cell shows a dict with
# the keys 'means', 'metadata' and 'pvalues'.
res1 = sq.gr.ligrec(
    adata,
    n_perms=2000,
    cluster_key="celltype",
    copy=True,
    use_raw=False,
    transmitter_params={"categories": "ligand"},
    receiver_params={"categories": "receptor"}
)
# Notebook-style bare expression: displays the result when run in a notebook.
res1

{'means': cluster_1 B cell ... Myeloid cell
cluster_2 B cell CD4 T cell ... ILC Myeloid cell
source target ...
FYN NTRK2 0.000000 0.020737 ... 0.000000 0.000000
CSF1 NTRK2 0.000000 0.003346 ... 0.000000 0.000000
HGF NTRK2 0.000000 0.002333 ... 0.000000 0.000000
AREG NTRK2 0.000000 0.177876 ... 0.000000 0.000000
PDGFC NTRK2 0.000000 0.001336 ... 0.000000 0.000000
... ... ... ... ... ...
SERPINF1 PLXDC2 0.009675 0.009378 ... 0.045745 0.186307
HPGDS PTGDR 0.004825 0.017215 ... 0.280152 0.043597
PTGDR2 0.000000 0.002751 ... 0.000000 0.045051
EBI3 IL12RB2 0.012528 0.015942 ... 0.017023 0.014924
VSTM1 ADGRG3 0.000805 0.000315 ... 0.075256 0.070944
[1167 rows x 25 columns],
'metadata': aspect_intercell_source ... uniprot_intercell_target
source target ...
FYN NTRK2 functional ... Q16620
CSF1 NTRK2 functional ... Q16620
HGF NTRK2 functional ... Q16620
AREG NTRK2 functional ... Q16620
PDGFC NTRK2 functional ... Q16620
... ... ... ...
SERPINF1 PLXDC2 functional ... Q6UX71
HPGDS PTGDR functional ... Q13258
PTGDR2 functional ... Q9Y5Y4
EBI3 IL12RB2 functional ... COMPLEX:P40189_Q99665
VSTM1 ADGRG3 functional ... Q86Y34
[1167 rows x 42 columns],
'pvalues': cluster_1 B cell ... Myeloid cell
cluster_2 B cell CD4 T cell CD8 T cell ... CD8 T cell ILC Myeloid cell
source target ...
FYN NTRK2 NaN NaN NaN ... NaN NaN NaN
CSF1 NTRK2 NaN NaN NaN ... NaN NaN NaN
HGF NTRK2 NaN NaN NaN ... NaN NaN NaN
AREG NTRK2 NaN NaN NaN ... NaN NaN NaN
PDGFC NTRK2 NaN NaN NaN ... NaN NaN NaN
... ... ... ... ... ... ... ...
SERPINF1 PLXDC2 NaN NaN NaN ... NaN 0.998 0.0
HPGDS PTGDR NaN NaN NaN ... 0.0 0.000 NaN
PTGDR2 NaN NaN NaN ... NaN NaN NaN
EBI3 IL12RB2 NaN NaN NaN ... NaN NaN NaN
VSTM1 ADGRG3 NaN NaN NaN ... NaN NaN NaN
[1167 rows x 25 columns]}

# Plot the ligand-receptor result computed before correction.
sq.pl.ligrec(res1, alpha=0.005, save='crc_ligen_origin.pdf')
# NOTE(review): the next line is a figure placeholder left over from the web
# page, not code; `index` is not defined anywhere in the visible source.
index.png
# Split by batch label and re-concatenate so that all 'batch2' cells come
# first, followed by all 'batch1' cells.
adata1 = adata[adata.obs['batch'] == 'batch2']
adata2 = adata[adata.obs['batch'] == 'batch1']
adata = adata1.concatenate(adata2, batch_categories=['batch2','batch1'])
# 2-D UMAP embedding of the uncorrected expression matrix.
data_umap = umap.UMAP().fit_transform(adata.X)
# Same embedding coloured by batch ...
scprep.plot.scatter2d(data_umap,adata.obs['batch'],figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch before removal')
# NOTE(review): figure placeholder from the web page, not code.
index.png
# ... and coloured by cell type.
scprep.plot.scatter2d(data_umap,adata.obs['celltype'],figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype before removal')
# NOTE(review): figure placeholder from the web page, not code.
index.png
# `adata_all` is an alias of `adata` (no copy), so the assignment below also
# changes `adata.X`.
adata_all = adata
# NOTE(review): `.todense()` presumably expects a SciPy sparse matrix here;
# confirm that X is sparse at this point, since the loom was read with sparse=False.
adata_all.X = adata_all.X.todense()
# Cosine similarity between two vectors (larger means more similar).
@jit(nopython=True)
def pdist(vec1,vec2):
    """Return dot(vec1, vec2) divided by the product of their Euclidean norms."""
    dot_product = np.dot(vec1, vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return dot_product / norm_product
# Brute-force nearest-neighbour matching with numba.
# NOTE: a second function with the same name is defined right after this one,
# so at module level that later definition replaces this version.
@jit(nopython=True)
def find_correlation_index(frame1, frame2):
  """For each row of ``frame2`` find the row of ``frame1`` with the largest ``pdist``.

  Returns a list with one ``(i, j)`` tuple per row ``i`` of ``frame2``, where
  ``j`` is the index of the best-scoring row of ``frame1``.
  """
  # Pre-filled with (1,1) placeholders; every entry is overwritten below.
  result=[(1,1) for _ in range(len(frame2))]
  for i in range(len(frame2)):
    # -10 is below any score compared here, so the first candidate always wins.
    max_dist = -10
    it1=0
    it2=0
    for j in range(len(frame1)):
      dist = pdist(frame2[i],frame1[j])
      # Strict '>' keeps the earliest row of frame1 when scores tie.
      if dist>max_dist:
        max_dist = dist
        it1 = i
        it2 = j 
    result[i] = (it1, it2)
  return result
# Vectorised way to get the index pairs: one cdist call followed by a
# row-wise argmin, which the authors report to be faster.
def find_correlation_index(frame1, frame2):
    """Match every row of ``frame2`` to its closest row of ``frame1``.

    Closeness is the cosine distance computed by ``cdist``.

    Returns:
        A list of ``(i, j)`` tuples, one per row ``i`` of ``frame2``, where
        ``j`` is the index of the row of ``frame1`` with the smallest distance.
    """
    cosine_dist = cdist(frame2, frame1, metric='cosine')
    nearest = cosine_dist.argmin(axis=1)
    return list(enumerate(nearest))
# Work on copies so the steps below do not modify `adata`.
adata_all = adata.copy()
adata3 = adata_all.copy()
# The batch label of the first cell defines the batch to be corrected
# (`batch_adata`); all other cells form the reference (`ref_adata`).
# NOTE(review): `obs['batch'][0]` is a positional lookup on a Series — confirm it
# still behaves as intended with the pandas version in use.
ref_adata = adata3[adata3.obs['batch'] != adata3.obs['batch'][0]]
batch_adata = adata3[adata3.obs['batch'] == adata3.obs['batch'][0]]
# Cells per batch as a plain dict; not read again in the code shown here.
c=Counter(adata3.obs['batch'])
c=dict(c)
# One (batch_row, ref_row) index pair per cell of the batch to be corrected.
ind_list = find_correlation_index(ref_adata.X, batch_adata.X)
common_pair = ind_list
donar_1_d = ref_adata.X
donar_2_d = batch_adata.X
result=[]
result1=[]
# Gather the paired rows: i[1] indexes the reference, i[0] the batch.
for i in common_pair:
  result.append(donar_1_d[i[1],:])
  result1.append(donar_2_d[i[0],:])
donar_1_t=np.array(result)
donar_2_t=np.array(result1)
# Generator input is the batch side; the matched reference rows are the target.
train_data = donar_2_t
train_label = donar_1_t
def training_set_generator(frame1,frame2,ref,batch):
    """Build paired training arrays from matched reference and batch rows.

    ``find_correlation_index(frame1, frame2)`` supplies ``(batch_idx, ref_idx)``
    pairs. For each pair the corresponding row of ``ref`` and of ``batch`` is
    collected.

    Returns:
        Tuple ``(ref_rows, batch_rows)`` of NumPy arrays, aligned pair by pair.
    """
    pairs = find_correlation_index(frame1, frame2)
    ref_rows = [ref[pair[1], :] for pair in pairs]
    batch_rows = [batch[pair[0], :] for pair in pairs]
    return np.array(ref_rows), np.array(batch_rows)
# Fix the NumPy, PyTorch CPU and PyTorch CUDA random seeds to the same value.
np.random.seed(999)
torch.manual_seed(999)
torch.cuda.manual_seed_all(999)
class Mish(nn.Module):
    """Element-wise Mish activation: ``x * tanh(softplus(x))``. Has no parameters."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate
# WGAN critic; it does not need to use batch normalization based on the WGAN paper.
class discriminator(nn.Module):
    """Fully connected critic mapping an expression vector to a single score.

    Args:
        input_dim: number of input features (genes). Defaults to 2000, the
            width that was previously hard-coded, so ``discriminator()``
            builds exactly the same network as before.
    """

    def __init__(self, input_dim=2000):
        super(discriminator, self).__init__()
        # Layer order and the attribute name `dis` are unchanged, so existing
        # state-dict keys stay valid.
        self.dis = nn.Sequential(
            nn.Linear(input_dim, 1024),
            Mish(),
            nn.Linear(1024, 512),
            Mish(),
            nn.Linear(512, 256),
            Mish(),
            nn.Linear(256, 128),
            Mish(),
            nn.Linear(128, 1),
            Mish()
        )

    def forward(self, x):
        """Return the critic score for ``x`` (last dimension ``input_dim``)."""
        return self.dis(x)
 
 
# WGAN generator
# Requires batch normalization
class generator(nn.Module):
    """Residual encoder-decoder that maps a batch-side expression vector to a corrected one.

    Args:
        input_dim: number of input and output features (genes). Defaults to
            2000, the width that was previously hard-coded, so ``generator()``
            builds exactly the same network as before.
    """

    def __init__(self, input_dim=2000):
        super(generator, self).__init__()
        self.relu_l = nn.ReLU(True)
        # Layer order and the attribute name `gen` are unchanged, so existing
        # state-dict keys stay valid.
        self.gen = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024, eps = 1e-7, momentum=0.01),
            nn.Dropout(0.5),
            Mish(),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(1024, input_dim),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        """Return ``ReLU(gen(x) + x)``: a residual correction clipped at zero."""
        gre = self.gen(x)
        return self.relu_l(gre+x)    #residual network
 
 
# Create the critic and generator objects.
D = discriminator()
G = generator()

# Move both networks to the GPU when one is available; otherwise they stay on the CPU.
if torch.cuda.is_available():
  D = D.cuda()
  G = G.cuda()
# calculate gradient penalty
def calculate_gradient_penalty(real_data, fake_data, D):
  """Return the gradient penalty of critic ``D`` on random interpolates.

  One mixing weight per sample is drawn uniformly from [0, 1). Each interpolate
  is ``eta * real + (1 - eta) * fake``, and the penalty is the mean of
  ``(||dD/dx||_2 - 1) ** 2`` over the batch.

  Args:
      real_data: 2-D tensor of real samples, shape (batch, features).
      fake_data: 2-D tensor of generated samples with the same shape.
      D: callable critic taking a (batch, features) tensor.

  Returns:
      Scalar tensor; it stays attached to the graph (``create_graph=True``)
      so it can be back-propagated through as part of the critic loss.
  """
  # Follow the inputs' device and dtype. The previous code chose the device
  # from torch.cuda.is_available(), which breaks for CPU tensors on a machine
  # that has a GPU. Sampling stays on the CPU so a fixed seed gives the same
  # mixing weights as before.
  eta = torch.empty(real_data.size(0), 1, dtype=torch.float32).uniform_(0, 1)
  eta = eta.to(device=real_data.device, dtype=real_data.dtype)
  eta = eta.expand(real_data.size(0), real_data.size(1))

  interpolated = eta * real_data + ((1 - eta) * fake_data)
  # Make the interpolates a fresh leaf that tracks gradients (replaces the
  # deprecated Variable wrapper).
  interpolated = interpolated.detach().requires_grad_(True)

  prob_interpolated = D(interpolated)
  # Gradient of the critic output with respect to the interpolated inputs.
  gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                  grad_outputs=torch.ones_like(prob_interpolated),
                                  create_graph=True, retain_graph=True)[0]
  grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
  return grad_penalty
@jit(nopython = True)
def determine_batch(val1):
  """Return the first candidate batch size ``b`` with ``val1 % b != 1``.

  Candidates are tried in the order 32, 40, 52, 64, 128, 256. If every
  candidate leaves a remainder of exactly 1, ``val1`` itself is returned.
  """
  candidates = [32,40,52,64,128,256]
  for size in candidates:
    if val1 % size != 1:
      return size
  return val1
# parameters
EPOCH = 40
# Number of paired training rows; the training loop steps through them in chunks of `batch`.
MAX_ITER = train_data.shape[0]
batch = 128
# NOTE(review): b1/b2 are not passed to the optimizers below, so AdamW runs
# with its own default betas — confirm whether that is intended.
b1 = 0.9
b2 = 0.999
# The training loop divides the gradient penalty by this value, i.e. the
# penalty is weighted by 100.
lambda_1 = 1/100


d_optimizer = torch.optim.AdamW(D.parameters(), lr=0.0001)
g_optimizer = torch.optim.AdamW(G.parameters(), lr=0.0001)

# Cells per batch as a plain dict; not read again in the code shown here.
c=Counter(adata3.obs['batch'])
c=dict(c)

# Loss histories filled by the training loop.
err_G = []
err_D = []
# Early-stop flag checked after every epoch; nothing in the visible code sets it to 1.
stop = 0
# Global step counter. NOTE: this name shadows the built-in `iter`.
iter = 0
#####################Since we only have two batches, so we adopt easier structure
# Feed batches on the device the models actually live on. The models are only
# moved to CUDA when it is available, so an unconditional .cuda() here would
# crash on CPU-only machines.
train_device = next(G.parameters()).device
for epoch in range(EPOCH):
  print(epoch)
  for time in range(0,MAX_ITER,batch):
    true_data = torch.FloatTensor(train_label[time:time+batch,:]).to(train_device)
    false_data = torch.FloatTensor(train_data[time:time+batch,:]).to(train_device)

    # BatchNorm1d in training mode cannot normalise a single sample, so skip a
    # trailing batch of size 1 instead of crashing.
    if true_data.size(0) < 2:
      continue

    # ---- critic (D) step ----
    d_optimizer.zero_grad()

    real_out = D(true_data)
    real_label_loss = -torch.mean(real_out)

    # Store a plain float; appending the tensor kept its autograd graph alive.
    err_D.append(real_label_loss.item())

    # Generated samples are detached so this step only updates the critic.
    fake_out_new = G(false_data).detach()
    fake_out = D(fake_out_new)

    div = calculate_gradient_penalty(true_data, fake_out_new, D)

    # Wasserstein critic loss plus gradient penalty weighted by 1/lambda_1.
    label_loss = real_label_loss+torch.mean(fake_out)+div/lambda_1
    label_loss.backward()

    err_D.append(label_loss.cpu().item())

    d_optimizer.step()

    # ---- generator (G) step ----
    real_out = G(false_data)
    real_output = D(real_out)

    # The generator tries to raise the critic score of its outputs.
    real_loss1 = -torch.mean(real_output)
    err_G.append(real_loss1.cpu().item())

    g_optimizer.zero_grad()

    real_loss1.backward()
    g_optimizer.step()

    if(time%100==0):
      print("g step loss",real_loss1)
    iter += 1

  if stop == 1:
    break

# Switch to inference mode (affects the dropout and batch-norm layers).
G.eval()
# Run on the generator's own device rather than an unconditional .cuda(), and
# disable autograd so no graph is built for the full dataset.
with torch.no_grad():
  test_data = torch.FloatTensor(train_data).to(next(G.parameters()).device)
  remove_batch_data = G(test_data).cpu().numpy()

# Corrected batch rows first, then the untouched reference rows.
data =np.vstack([remove_batch_data, donar_1_d])

data.shape

(53018, 3892)

# 2-D UMAP embedding of the stacked matrix (corrected batch rows + reference rows).
data_umap = umap.UMAP().fit_transform(data)

# Colours come from adata3.obs.
# NOTE(review): this assumes the rows of `data` are in the same order as the
# cells of adata3 — confirm.
scprep.plot.scatter2d(data_umap, c=adata3.obs['batch'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch after removal')

scprep.plot.scatter2d(data_umap, c=adata3.obs['celltype'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype after removal')

# Minimal obs table for the ASW evaluation, which reads the columns 'batch'
# and 'louvain'; the cell-type labels are stored under the name 'louvain'.
obs = pd.DataFrame()

obs['batch'] = adata3.obs['batch']
obs['louvain'] = adata3.obs['celltype']

funcdata = sc.AnnData(data, obs)


# Same construction as a few lines above, repeated.
funcdata = sc.AnnData(data, obs)
# Called with default save_dir/save_fn, so the CSV is written to a file named '.csv'.
silhouette_coeff_ASW(funcdata)

Evaluation

# Plot defaults and logging verbosity for scanpy.
sc.set_figure_params(dpi=100,color_map = 'viridis_r',fontsize=25)
sc.settings.verbosity = 1
sc.logging.print_header()
# Load the 'gold' dataset and the concatenated dataset from loom files.
adata_gold = sc.read_loom('CRC_gold.loom', sparse=False)
adata = sc.read_loom('CRC_CONCAT.loom', sparse=False)
adata_all = imap.stage1.data_preprocess(adata,'batch')
# Reuse the louvain labels as the 'celltype' column.
adata_gold.obs['celltype'] = adata_gold.obs.louvain.copy()
# Overwrite the gene and cell names with those of the preprocessed data.
# NOTE(review): presumably both objects have the same shape and ordering — confirm.
adata_gold.var_names = adata_all.var_names.copy()
adata_gold.obs_names = adata_all.obs_names.copy()
# Rank marker genes per cell type with the Wilcoxon method, then plot the top 2 per group.
sc.tl.rank_genes_groups(adata_gold, groupby='celltype', method='wilcoxon')
sc.pl.rank_genes_groups_heatmap(adata_gold, n_genes=2, use_raw=False, swap_axes=True, vmin=-3, vmax=3, cmap='bwr',figsize=(20,7), show=False)
# NOTE(review): figure placeholder from the web page, not code.
index.png
sc.pl.rank_genes_groups_tracksplot(adata_gold, n_genes=2,figsize=(25,7))
# NOTE(review): figure placeholder from the web page, not code.
index.png
# Alias (no copy) of the gold data.
adata_new = adata_gold
# Ligand-receptor analysis on the gold data, grouped by obs["celltype"], 1000 permutations.
res = sq.gr.ligrec(
    adata_new,
    n_perms=1000,
    cluster_key="celltype",
    copy=True,
    use_raw=False,
    transmitter_params={"categories": "ligand"},
    receiver_params={"categories": "receptor"}
)
# NOTE(review): `df_new` is not defined anywhere in the visible source, so this
# line fails as written — confirm which table was meant to be exported.
df_new.to_csv('cellphone_new10x.csv')
sq.pl.ligrec(res, alpha=0.005,save='crc_ligen.pdf')
图片.png

生活很好,有你更好

上一篇 下一篇

猜你喜欢

热点阅读