r"""
Useful functions
"""
import time
import math
from typing import List, Mapping, Optional, Union
from pathlib import Path
from joblib import Parallel, delayed
import faiss
import scanpy as sc
import torch
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import euclidean_distances
from anndata import AnnData
from ..utils import get_free_gpu
from .train import train_GAN, train_reconstruct
from .graphmodel import LGCN, LGCN_mlp, WDiscriminator, ReconDNN
from .loaddata import load_anndatas
from .preprocess import Cal_Spatial_Net
def run_LGCN(features: List,
             edges: List,
             LGCN_layer: Optional[int] = 2
             ):
"""
Run LGCN model
Parameters
----------
features
list of graph node features
edges
list of graph edges
LGCN_layer
LGCN layer number, we suggest set 2 for barcode based and 4 for fluorescence based
"""
    try:
        gpu_index = get_free_gpu()
        print(f"Choose GPU:{gpu_index} as device")
        device = torch.device(f'cuda:{gpu_index}' if torch.cuda.is_available() else 'cpu')
    except Exception:
        print('GPU is not available, falling back to CPU')
        device = torch.device('cpu')
for i in range(len(features)):
features[i] = features[i].to(device)
for j in range(len(edges)):
edges[j] = edges[j].to(device)
    LGCN_model = LGCN(input_size=features[0].size(1), K=LGCN_layer).to(device=device)
time1 = time.time()
embd0 = LGCN_model(features[0], edges[0])
embd1 = LGCN_model(features[1], edges[1])
run_time = time.time() - time1
print(f'LGCN time: {run_time}')
return embd0, embd1, run_time
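
# Usage sketch for `run_LGCN` (illustrative only): the toy tensors below are
# hypothetical stand-ins for real data, assuming edges are edge-index tensors
# of shape (2, n_edges).
#
#   feat0, feat1 = torch.randn(500, 50), torch.randn(600, 50)
#   edge0 = torch.randint(0, 500, (2, 2000))  # random edge index for graph 0
#   edge1 = torch.randint(0, 600, (2, 2400))  # random edge index for graph 1
#   embd0, embd1, run_time = run_LGCN([feat0, feat1], [edge0, edge1], LGCN_layer=2)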
def run_SLAT(features: List,
             edges: List,
             epochs: Optional[int] = 6,
             LGCN_layer: Optional[int] = 1,
             mlp_hidden: Optional[int] = 256,
             hidden_size: Optional[int] = 2048,
             alpha: Optional[float] = 0.01,
             anchor_scale: Optional[float] = 0.8,
             lr_mlp: Optional[float] = 0.0001,
             lr_wd: Optional[float] = 0.0001,
             lr_recon: Optional[float] = 0.01,
             batch_d_per_iter: Optional[int] = 5,
             batch_r_per_iter: Optional[int] = 10
             ) -> List:
r"""
Run SLAT model
Parameters
----------
features
list of graph node features
edges
list of graph edges
epochs
epoch number of SLAT (not exceed 10)
LGCN_layer
LGCN layer number, we suggest set 1 for barcode based and 4 for fluorescence based
mlp_hidden
MLP hidden layer size
hidden_size
size of LGCN output
transform
if use transform
alpha
scale of loss
anchor_scale
ratio of cells selected as pairs
lr_mlp
learning rate of MLP
lr_wd
learning rate of WGAN discriminator
lr_recon
learning rate of reconstruction
batch_d_per_iter
batch number for WGAN train per iter
batch_r_per_iter
batch number for reconstruct train per iter
Return
----------
embd0
cell embedding of dataset1
embd1
cell embedding of dataset2
time
run time of SLAT model
"""
    try:
        gpu_index = get_free_gpu()
        print(f"Choose GPU:{gpu_index} as device")
        device = torch.device(f'cuda:{gpu_index}' if torch.cuda.is_available() else 'cpu')
    except Exception:
        print('GPU is not available, falling back to CPU')
        device = torch.device('cpu')
for i in range(len(features)):
features[i] = features[i].to(device)
for j in range(len(edges)):
edges[j] = edges[j].to(device)
feature_size = features[0].size(1)
feature_output_size = hidden_size
LGCN_model = LGCN_mlp(feature_size, hidden_size, K=LGCN_layer, hidden_size=mlp_hidden).to(device)
optimizer_LGCN = torch.optim.Adam(LGCN_model.parameters(), lr=lr_mlp, weight_decay=5e-4)
wdiscriminator = WDiscriminator(feature_output_size).to(device)
optimizer_wd = torch.optim.Adam(wdiscriminator.parameters(), lr=lr_wd, weight_decay=5e-4)
recon_model0 = ReconDNN(feature_output_size, feature_size).to(device)
recon_model1 = ReconDNN(feature_output_size, feature_size).to(device)
optimizer_recon0 = torch.optim.Adam(recon_model0.parameters(), lr=lr_recon, weight_decay=5e-4)
optimizer_recon1 = torch.optim.Adam(recon_model1.parameters(), lr=lr_recon, weight_decay=5e-4)
print('Running')
time1 = time.time()
for i in range(1, epochs + 1):
print(f'---------- epochs: {i} ----------')
LGCN_model.train()
optimizer_LGCN.zero_grad()
embd0 = LGCN_model(features[0], edges[0])
embd1 = LGCN_model(features[1], edges[1])
        loss = train_GAN(wdiscriminator, optimizer_wd, [embd0, embd1],
                         batch_d_per_iter=batch_d_per_iter, anchor_scale=anchor_scale)
        loss_feature = train_reconstruct([recon_model0, recon_model1], [optimizer_recon0, optimizer_recon1],
                                         [embd0, embd1], features, batch_r_per_iter=batch_r_per_iter)
        loss = (1 - alpha) * loss + alpha * loss_feature
loss.backward()
optimizer_LGCN.step()
LGCN_model.eval()
embd0 = LGCN_model(features[0], edges[0])
embd1 = LGCN_model(features[1], edges[1])
time2 = time.time()
print('Training model time: %.2f' % (time2-time1))
# torch.cuda.empty_cache()
return embd0, embd1, time2-time1
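
# Usage sketch for `run_SLAT` (illustrative only): `adata1` and `adata2` are
# hypothetical AnnData objects; features and edges are assumed to come from
# `load_anndatas`, as done in `run_SLAT_multi` below.
#
#   edges, features = load_anndatas([adata1, adata2], feature='DPCA')
#   embd0, embd1, elapsed = run_SLAT(features, edges, epochs=6, LGCN_layer=1)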
def spatial_match(embds: List[torch.Tensor],
                  reorder: Optional[bool] = True,
                  top_n: Optional[int] = 20,
                  smooth: Optional[bool] = True,
                  smooth_range: Optional[int] = 20,
                  scale_coord: Optional[bool] = True,
                  adatas: Optional[List[AnnData]] = None,
                  verbose: Optional[bool] = False
                  ) -> List[Union[np.ndarray, torch.Tensor]]:
r"""
Use embedding to match cells from different datasets based on cosine similarity
Parameters
----------
embds
list of embeddings
reorder
if reorder embedding by cell numbers
top_n
return top n of cosine similarity
smooth
if smooth the mapping by Euclid distance
smooth_range
use how many candidates to do smooth
scale_coord
if scale the coordinate to [0,1]
adatas
list of adata object
verbose
if print log
Note
----------
Automatically use larger dataset as source
Return
----------
Best matching, Top n matching and cosine similarity matrix of top n
Note
----------
Use faiss to accelerate, refer https://github.com/facebookresearch/faiss/issues/95
"""
if reorder and embds[0].shape[0] < embds[1].shape[0]:
embd0 = embds[1]
embd1 = embds[0]
adatas = adatas[::-1] if adatas is not None else None
else:
embd0 = embds[0]
embd1 = embds[1]
index = faiss.index_factory(embd1.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
embd0_np = embd0.detach().cpu().numpy() if torch.is_tensor(embd0) else embd0
embd1_np = embd1.detach().cpu().numpy() if torch.is_tensor(embd1) else embd1
embd0_np = embd0_np.copy().astype('float32')
embd1_np = embd1_np.copy().astype('float32')
faiss.normalize_L2(embd0_np)
faiss.normalize_L2(embd1_np)
index.add(embd0_np)
distance, order = index.search(embd1_np, top_n)
best = []
    if smooth and adatas is not None:
        smooth_range = min(smooth_range, top_n)
        if verbose:
            print('Smoothing mapping, make sure the datasets are in the same orientation')
if scale_coord:
# scale spatial coordinate of every adata to [0,1]
adata1_coord = adatas[0].obsm['spatial'].copy()
adata2_coord = adatas[1].obsm['spatial'].copy()
for i in range(2):
adata1_coord[:,i] = (adata1_coord[:,i]-np.min(adata1_coord[:,i]))/(np.max(adata1_coord[:,i])-np.min(adata1_coord[:,i]))
adata2_coord[:,i] = (adata2_coord[:,i]-np.min(adata2_coord[:,i]))/(np.max(adata2_coord[:,i])-np.min(adata2_coord[:,i]))
for query in range(embd1_np.shape[0]):
ref_list = order[query, :smooth_range]
dis = euclidean_distances(adata2_coord[query,:].reshape(1, -1),
adata1_coord[ref_list,:])
best.append(ref_list[np.argmin(dis)])
else:
best = order[:,0]
return np.array(best), order, distance
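
# Usage sketch for `spatial_match` (illustrative only), continuing the
# `run_SLAT` example above; smoothing assumes spatial coordinates are stored
# in `.obsm['spatial']` of each AnnData.
#
#   best, order, distance = spatial_match([embd0, embd1], smooth=True,
#                                         adatas=[adata1, adata2])
#   # best[i] is the reference-cell index matched to query cell i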
def run_SLAT_multi(adatas: List[AnnData],
                   order: Optional[list] = None,
                   k_cutoff: Optional[int] = 10,
                   feature: Optional[str] = 'DPCA',
                   cos_cutoff: Optional[float] = 0.85,
                   n_jobs: Optional[int] = -1,
                   top_K: Optional[int] = 50
                   ) -> List[np.ndarray]:
r"""
Run SLAT on multi-dataset for 3D re-construct
Parameters
-----------
adatas
list of adatas
order
biological order of the slides
k_cutoff
k nearest neighbor
feature
feature to use, one of ['DPCA', 'PCA', 'harmony']
cos_cutoff
cosine similarity cutoff of mapping results
n_jobs
cpu cores to use
top_K
top K smooth mapping results
Return
----------
matching_list
list of precise mapping results
index_list
list of top mapping index
"""
    order = range(len(adatas)) if order is None else order
n_jobs = len(adatas) + 1 if n_jobs < 0 else n_jobs
# for adata in adatas:
# Cal_Spatial_Net(adata, k_cutoff=k_cutoff, model='KNN')
adatas = Parallel(n_jobs=n_jobs)(delayed(Cal_Spatial_Net)(adata, k_cutoff=k_cutoff, model='KNN',return_data=True) for adata in adatas)
    def parallel_SLAT(a1, a2, i):
        print(f'Parallel mapping dataset:{i} --- dataset:{i+1}')
        edges, features = load_anndatas([a1, a2], feature=feature, check_order=False)
        embd0, embd1, _ = run_SLAT(features, edges)
        best, index, distance = spatial_match([embd0, embd1], reorder=False,
                                              smooth_range=top_K, adatas=[a1, a2])
        return best, index, distance
    unfiltered_zip_res = Parallel(n_jobs=n_jobs)(delayed(parallel_SLAT)(a1, a2, i)
                                                 for i, (a1, a2) in enumerate(zip(adatas, adatas[1:])))
matching_list = []
for (best, index, distance) in unfiltered_zip_res:
        matching = np.array([np.arange(index.shape[0]), best])
        matching_filter = matching[:, distance[:, 0] > cos_cutoff]
matching_list.append(matching_filter)
# assert [x.shape[1] for x in matching_list] == [y.shape[0] for y in adatas[1:]] # check order
return matching_list, unfiltered_zip_res
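
# Usage sketch for `run_SLAT_multi` (illustrative only): the .h5ad file names
# are hypothetical; each AnnData is assumed to carry spatial coordinates.
#
#   adatas = [sc.read_h5ad(f) for f in ('slide1.h5ad', 'slide2.h5ad', 'slide3.h5ad')]
#   matching_list, raw_results = run_SLAT_multi(adatas, k_cutoff=10, feature='DPCA')
#   # matching_list[i] maps cells of slide i+1 onto slide i (filtered by cos_cutoff)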
def calc_k_neighbor(features: List[torch.Tensor],
                    k_list: List[int]
                    ) -> Mapping:
r"""
cal k nearest neighbor
Parameters:
----------
features
feature list to find KNN
k
list of k to find (must have 2 elements)
"""
assert len(k_list) == 2
k_list = sorted(k_list)
nbr_dict = {}
for k in k_list:
nbr_dict[k] = [None, None]
    for i, feature in enumerate(features):  # loop over feature sets first
        for k in k_list:  # then over k values
            # note: kneighbors() indices include the query point itself
            nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree', n_jobs=-1).fit(feature)
            nbr_dict[k][i] = nbrs
return nbr_dict
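
# Usage sketch for `calc_k_neighbor` (illustrative only), with hypothetical
# toy feature matrices:
#
#   feats = [np.random.rand(300, 30), np.random.rand(400, 30)]
#   nbr_dict = calc_k_neighbor(feats, k_list=[10, 20])
#   dist, idx = nbr_dict[10][0].kneighbors(feats[0])  # indices include the point itself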
def add_noise(adata: AnnData,
              noise: Optional[str] = 'nb',
              inverse_noise: Optional[float] = 5
              ) -> AnnData:
r"""
Add poisson or negative binomial noise on raw counts
also run scanpy pipeline to PCA step
Parameters
----------
adata
anndata object
noise
type of noise, one of 'poisson' or 'nb'
inverse_noise
if noise is 'nb', control the noise level
(smaller means larger variance)
"""
    if 'counts' not in adata.layers:
        adata.layers["counts"] = adata.X.copy()
    mu = torch.tensor(adata.X.todense())
    if noise.lower() == 'poisson':
        adata.X = torch.distributions.poisson.Poisson(mu).sample().numpy()
    elif noise.lower() == 'nb':
        adata.X = torch.distributions.negative_binomial.NegativeBinomial(
            inverse_noise, logits=(mu.log() - math.log(inverse_noise))
        ).sample().numpy()
    else:
        raise NotImplementedError(f'Unknown noise type: {noise}')
    return adata.copy()
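
# Usage sketch for `add_noise` (illustrative only): assumes `adata.X` holds
# raw counts as a sparse matrix.
#
#   noisy = add_noise(adata, noise='nb', inverse_noise=5)  # negative binomial
#   noisy = add_noise(adata, noise='poisson')              # Poisson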