发布时间:2023-01-20 10:30
最近在阅读和复现各个大佬的空转论文,记录、交流学习下,如有错误,欢迎指出。
前言
首先是STAGATE,是中科院提出来的方法,具体发表在NC上,主要思路与空转普遍的思路类似,提取基因表达、空间信息和图像特征,然后进行聚类,以识别每个spot的类型。当然,STAGATE,没有用图像信息,就已经是是目前已发表论文中最好的结果了。
总体架构
总体架构如下。
总体来说模型就是一个四层的AutoEncode,两层编码器两层解码器,只是每一层都换成了GAT。将基因表达数据X输入进去再重构出来X’,损失函数自然而然的就是X和X’的MSE。值得注意的是第二层和第三层,第一层和第四层分别共用一组权重W,为转置关系,这点在图上已经表明。如果是spot级别的数据,模型就已经全部讲完了,如果是细胞级别的数据,还会构建SNN,即重新构建一个新的GAT的邻接矩阵,然后每一层的结果是新的邻接矩阵和旧邻接矩阵构成的GAT加权求和为下一层的输入。
代码
作者最初发布的是tensorflow1的代码,今年三月份又公布了torch的代码,但是torch版本没有构建SNN,在细节上与tensorflow也略有不同,比如损失函数,tensorflow中除了MSE,又加入了权重损失防止过拟合,具体的在代码中我发现的都会提到。下面我试着根据torch版本的代码来说下我对这篇论文的理解。(最好在linux系统上运行,在windows上总是会出现各种奇怪错误)
首先是数据预处理。包括数据读取,在根据论文下载数据就好。然后是Normalization,选择高表达基因,正则化,取对数。再然后是读取真实标签用于最后测评并做了可视化。
input_dir = os.path.join('Data', section_id)
adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
adata.var_names_make_unique()
#Normalization
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
Ann_df = pd.read_csv(os.path.join('Data',
section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)
adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
plt.rcParams["figure.figsize"] = (3, 3)
sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])
然后是spot和spot之间的距离。距离大于0小于150的spot构建邻接矩阵,在这个范围内认为有连接,邻接矩阵为1,否则是0。以下是计算符合距离范围的spot的距离,并保存adata.uns['Spatial_Net']中。
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
"""\
Construct the spatial neighbor networks.
Parameters
----------
adata
AnnData object of scanpy package.
rad_cutoff
radius cutoff when model='Radius'
k_cutoff
The number of nearest neighbors when model='KNN'
model
The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.
Returns
-------
The spatial networks are saved in adata.uns['Spatial_Net']
"""
assert(model in ['Radius', 'KNN'])
if verbose:
print('------Calculating spatial graph...')
coor = pd.DataFrame(adata.obsm['spatial'])
coor.index = adata.obs.index
coor.columns = ['imagerow', 'imagecol']
if model == 'Radius':
nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
KNN_list = []
for it in range(indices.shape[0]):
KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))
if model == 'KNN':
nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)
distances, indices = nbrs.kneighbors(coor)
KNN_list = []
for it in range(indices.shape[0]):
KNN_list.append(pd.DataFrame(zip([it]*indices.shape[1],indices[it,:], distances[it,:])))
KNN_df = pd.concat(KNN_list)
KNN_df.columns = ['Cell1', 'Cell2', 'Distance']
Spatial_Net = KNN_df.copy()
Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]
id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
if verbose:
print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))
adata.uns['Spatial_Net'] = Spatial_Net
随后是一个可视化,平均每个spot有多少个邻居。
def Stats_Spatial_Net(adata):
import matplotlib.pyplot as plt
Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]
Mean_edge = Num_edge/adata.shape[0]
plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))
plot_df = plot_df/adata.shape[0]
fig, ax = plt.subplots(figsize=[3,2])
plt.ylabel('Percentage')
plt.xlabel('')
plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge)
ax.bar(plot_df.index, plot_df)
下面就正式进入STAGATE的训练阶段了。
首先将是数据准备,包括两部分:根据挑选出来的邻居构建邻接矩阵和基因表达数据。
def Transfer_pytorch_Data(adata):
G_df = adata.uns['Spatial_Net'].copy()
cells = np.array(adata.obs_names)
cells_id_tran = dict(zip(cells, range(cells.shape[0])))
G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
G = G + sp.eye(G.shape[0])
edgeList = np.nonzero(G)
if type(adata.X) == np.ndarray:
data = Data(edge_index=torch.LongTensor(np.array(
[edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X)) # .todense()
else:
data = Data(edge_index=torch.LongTensor(np.array(
[edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X.todense())) # .todense()
return data
然后构建STAGATE模型 正如前边所说四层GAT,其中h2是最后的特征向量,h4是重建的基因表达数据。
class STAGATE(torch.nn.Module):
def __init__(self, hidden_dims):
super(STAGATE, self).__init__()
[in_dim, num_hidden, out_dim] = hidden_dims
self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
dropout=0, add_self_loops=False, bias=False)
self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
dropout=0, add_self_loops=False, bias=False)
self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
dropout=0, add_self_loops=False, bias=False)
self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
dropout=0, add_self_loops=False, bias=False)
def forward(self, features, edge_index):
h1 = F.elu(self.conv1(features, edge_index))
h2 = self.conv2(h1, edge_index, attention=False)
self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
h3 = F.elu(self.conv3(h2, edge_index, attention=True,
tied_attention=self.conv1.attentions))
h4 = self.conv4(h3, edge_index, attention=False)
return h2, h4 # F.log_softmax(x, dim=-1)
具体的GAT代码不放了,详见`"Graph Attention Networks"
具体训练代码如下,不同点是加了梯度截断,最后返回h2,或者说是z,也就是特征向量用于下一步聚类分析,保存到adata中。
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_list = []
for epoch in tqdm(range(1, n_epochs+1)):
model.train()
optimizer.zero_grad()
z, out = model(data.x, data.edge_index)
loss = F.mse_loss(data.x, out) #F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss_list.append(loss)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
optimizer.step()
model.eval()
z, out = model(data.x, data.edge_index)
STAGATE_rep = z.to('cpu').detach().numpy()
adata.obsm[key_added] = STAGATE_rep
if save_loss:
adata.uns['STAGATE_loss'] = loss
if save_reconstrction:
ReX = out.to('cpu').detach().numpy()
ReX[ReX<0] = 0
adata.layers['STAGATE_ReX'] = ReX
最后调用了R中的mclust包进行聚类。
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
"""\
Clustering using the mclust algorithm.
The parameters are the same as those in the R package mclust.
"""
np.random.seed(random_seed)
import rpy2.robjects as robjects
robjects.r.library("mclust")
import rpy2.robjects.numpy2ri
rpy2.robjects.numpy2ri.activate()
r_random_seed = robjects.r['set.seed']
r_random_seed(random_seed)
rmclust = robjects.r['Mclust']
res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
mclust_res = np.array(res[-2])
adata.obs['mclust'] = mclust_res
adata.obs['mclust'] = adata.obs['mclust'].astype('int')
adata.obs['mclust'] = adata.obs['mclust'].astype('category')
return adata
去掉缺失值并计算ARI。tensorflow版本和后续的数据分析解析等我看明白再来记录,最后附上测试DFPFC数据库的主函数。所有代码、数据和论文可以再github上下载,欢迎交流。
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from sklearn.metrics.cluster import adjusted_rand_score
# import sklearn
import STAGATE_pyG as STAGATE
os.environ['R_HOME'] = '/home/admin/anaconda3/envs/lib/R'
# os.environ['R_USER'] = '/home/admin/Anaconda3\Lib\site-packages/rpy2'
dataset = ["151507", "151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675",
"151676"]
knn = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
ARIlist = []
for section_id, k in zip(dataset, knn):
print(section_id,k)
input_dir = os.path.join('Data', section_id)
adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
adata.var_names_make_unique()
#Normalization
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
Ann_df = pd.read_csv(os.path.join('Data',
section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)
adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
plt.rcParams["figure.figsize"] = (3, 3)
sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])
STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)
STAGATE.Stats_Spatial_Net(adata)
adata = STAGATE.train_STAGATE(adata)
sc.pp.neighbors(adata, use_rep='STAGATE')
sc.tl.umap(adata)
adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)
obs_df = adata.obs.dropna()
ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])
ARIlist.append(ARI)
print('Adjusted rand index = %.2f' %ARI)
print("ari mean", np.mean(ARIlist))
print("ari median", np.median(ARIlist))