空间转录组 STAGATE

发布时间:2023-01-20 10:30

最近在阅读和复现各个大佬的空转论文,记录、交流学习下,如有错误,欢迎指出。

前言

首先是STAGATE,是中科院提出来的方法,具体发表在NC上,主要思路与空转普遍的思路类似,提取基因表达、空间信息和图像特征,然后进行聚类,以识别每个spot的类型。当然,STAGATE,没有用图像信息,就已经是是目前已发表论文中最好的结果了。

总体架构

总体架构如下。

总体来说模型就是一个四层的AutoEncode,两层编码器两层解码器,只是每一层都换成了GAT。将基因表达数据X输入进去再重构出来X’,损失函数自然而然的就是X和X’的MSE。值得注意的是第二层和第三层,第一层和第四层分别共用一组权重W,为转置关系,这点在图上已经表明。如果是spot级别的数据,模型就已经全部讲完了,如果是细胞级别的数据,还会构建SNN,即重新构建一个新的GAT的邻接矩阵,然后每一层的结果是新的邻接矩阵和旧邻接矩阵构成的GAT加权求和为下一层的输入。

代码

作者最初发布的是tensorflow1的代码,今年三月份又公布了torch的代码,但是torch版本没有构建SNN,在细节上与tensorflow也略有不同,比如损失函数,tensorflow中除了MSE,又加入了权重损失防止过拟合,具体的在代码中我发现的都会提到。下面我试着根据torch版本的代码来说下我对这篇论文的理解。(最好在linux系统上运行,在windows上总是会出现各种奇怪错误)

首先是数据预处理。包括数据读取,在根据论文下载数据就好。然后是Normalization,选择高表达基因,正则化,取对数。再然后是读取真实标签用于最后测评并做了可视化。

    input_dir = os.path.join('Data', section_id)
    adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()

    #Normalization
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    Ann_df = pd.read_csv(os.path.join('Data',
                                      section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)
    adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']

    plt.rcParams["figure.figsize"] = (3, 3)
    sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

然后是spot和spot之间的距离。距离大于0小于150的spot构建邻接矩阵,在这个范围内认为有连接,邻接矩阵为1,否则是0。以下是计算符合距离范围的spot的距离,并保存adata.uns['Spatial_Net']中。

def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
    """\
    Construct the spatial neighbor networks.

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    rad_cutoff
        radius cutoff when model='Radius'
    k_cutoff
        The number of nearest neighbors when model='KNN'
    model
        The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.
    
    Returns
    -------
    The spatial networks are saved in adata.uns['Spatial_Net']
    """

    assert(model in ['Radius', 'KNN'])
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    coor.columns = ['imagerow', 'imagecol']

    if model == 'Radius':
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
        distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
        KNN_list = []
        for it in range(indices.shape[0]):
            KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))
    
    if model == 'KNN':
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)
        distances, indices = nbrs.kneighbors(coor)
        KNN_list = []
        for it in range(indices.shape[0]):
            KNN_list.append(pd.DataFrame(zip([it]*indices.shape[1],indices[it,:], distances[it,:])))

    KNN_df = pd.concat(KNN_list)
    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']

    Spatial_Net = KNN_df.copy()
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]
    id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    if verbose:
        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net

随后是一个可视化,平均每个spot有多少个邻居。

def Stats_Spatial_Net(adata):
    import matplotlib.pyplot as plt
    Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]
    Mean_edge = Num_edge/adata.shape[0]
    plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))
    plot_df = plot_df/adata.shape[0]
    fig, ax = plt.subplots(figsize=[3,2])
    plt.ylabel('Percentage')
    plt.xlabel('')
    plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge)
    ax.bar(plot_df.index, plot_df)

下面就正式进入STAGATE的训练阶段了。

首先将是数据准备,包括两部分:根据挑选出来的邻居构建邻接矩阵和基因表达数据。

def Transfer_pytorch_Data(adata):
    G_df = adata.uns['Spatial_Net'].copy()
    cells = np.array(adata.obs_names)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
    G = G + sp.eye(G.shape[0])

    edgeList = np.nonzero(G)
    if type(adata.X) == np.ndarray:
        data = Data(edge_index=torch.LongTensor(np.array(
            [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X))  # .todense()
    else:
        data = Data(edge_index=torch.LongTensor(np.array(
            [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X.todense()))  # .todense()
    return data

然后构建STAGATE模型 正如前边所说四层GAT,其中h2是最后的特征向量,h4是重建的基因表达数据。

class STAGATE(torch.nn.Module):
    def __init__(self, hidden_dims):
        super(STAGATE, self).__init__()

        [in_dim, num_hidden, out_dim] = hidden_dims
        self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)

    def forward(self, features, edge_index):

        h1 = F.elu(self.conv1(features, edge_index))
        h2 = self.conv2(h1, edge_index, attention=False)
        self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
        self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
        self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
        self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
        h3 = F.elu(self.conv3(h2, edge_index, attention=True,
                              tied_attention=self.conv1.attentions))
        h4 = self.conv4(h3, edge_index, attention=False)

        return h2, h4  # F.log_softmax(x, dim=-1)

具体的GAT代码不放了,详见`"Graph Attention Networks"

具体训练代码如下,不同点是加了梯度截断,最后返回h2,或者说是z,也就是特征向量用于下一步聚类分析,保存到adata中。

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    loss_list = []
    for epoch in tqdm(range(1, n_epochs+1)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        loss = F.mse_loss(data.x, out) #F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss_list.append(loss)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
    
    model.eval()
    z, out = model(data.x, data.edge_index)
    
    STAGATE_rep = z.to('cpu').detach().numpy()
    adata.obsm[key_added] = STAGATE_rep

    if save_loss:
        adata.uns['STAGATE_loss'] = loss
    if save_reconstrction:
        ReX = out.to('cpu').detach().numpy()
        ReX[ReX<0] = 0
        adata.layers['STAGATE_ReX'] = ReX

最后调用了R中的mclust包进行聚类。

def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """\
    Clustering using the mclust algorithm.
    The parameters are the same as those in the R package mclust.
    """
    
    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri
    rpy2.robjects.numpy2ri.activate()
    r_random_seed = robjects.r['set.seed']
    r_random_seed(random_seed)
    rmclust = robjects.r['Mclust']

    res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
    mclust_res = np.array(res[-2])

    adata.obs['mclust'] = mclust_res
    adata.obs['mclust'] = adata.obs['mclust'].astype('int')
    adata.obs['mclust'] = adata.obs['mclust'].astype('category')
    return adata

去掉缺失值并计算ARI。tensorflow版本和后续的数据分析解析等我看明白再来记录,最后附上测试DFPFC数据库的主函数。所有代码、数据和论文可以再github上下载,欢迎交流。

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from sklearn.metrics.cluster import adjusted_rand_score
# import sklearn
import STAGATE_pyG as STAGATE
os.environ['R_HOME'] = '/home/admin/anaconda3/envs/lib/R'
# os.environ['R_USER'] = '/home/admin/Anaconda3\Lib\site-packages/rpy2'



dataset = ["151507", "151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675",
           "151676"]
knn = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
ARIlist = []
for section_id, k in zip(dataset, knn):
    print(section_id,k)
    input_dir = os.path.join('Data', section_id)
    adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()

    #Normalization
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    Ann_df = pd.read_csv(os.path.join('Data',
                                      section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)
    adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']

    plt.rcParams["figure.figsize"] = (3, 3)
    sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

    STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)
    STAGATE.Stats_Spatial_Net(adata)

    adata = STAGATE.train_STAGATE(adata)

    sc.pp.neighbors(adata, use_rep='STAGATE')
    sc.tl.umap(adata)
    adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)

    obs_df = adata.obs.dropna()
    ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])
    ARIlist.append(ARI)
    print('Adjusted rand index = %.2f' %ARI)
print("ari mean", np.mean(ARIlist))
print("ari median", np.median(ARIlist))

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号