Source code for scSLAT.model.preprocess

r"""
Data preprocess and build graph
"""
from typing import List, Optional, Union

import pandas as pd
import scanpy as sc
import torch
from anndata import AnnData
from torch_geometric.nn import knn_graph, radius_graph
from torch_geometric.utils import to_undirected


[docs]def Cal_Spatial_Net( adata: AnnData, rad_cutoff: Optional[Union[None, int]] = None, k_cutoff: Optional[Union[None, int]] = None, model: Optional[str] = "Radius", return_data: Optional[bool] = False, verbose: Optional[bool] = True, ) -> None: r""" Construct the spatial neighbor networks. Parameters ---------- adata AnnData object of scanpy package. rad_cutoff radius cutoff when model='Radius' k_cutoff The number of nearest neighbors when model='KNN' model The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. Returns ------- The spatial networks are saved in adata.uns['Spatial_Net'] """ assert model in ["Radius", "KNN"] if verbose: print("Calculating spatial neighbor graph ...") if model == "KNN": edge_index = knn_graph( x=torch.tensor(adata.obsm["spatial"]), flow="target_to_source", k=k_cutoff, loop=True, num_workers=8, ) edge_index = to_undirected( edge_index, num_nodes=adata.shape[0] ) # ensure the graph is undirected elif model == "Radius": edge_index = radius_graph( x=torch.tensor(adata.obsm["spatial"]), flow="target_to_source", r=rad_cutoff, loop=True, num_workers=8, ) graph_df = pd.DataFrame(edge_index.numpy().T, columns=["Cell1", "Cell2"]) id_cell_trans = dict(zip(range(adata.n_obs), adata.obs_names)) graph_df["Cell1"] = graph_df["Cell1"].map(id_cell_trans) graph_df["Cell2"] = graph_df["Cell2"].map(id_cell_trans) adata.uns["Spatial_Net"] = graph_df if verbose: print(f"The graph contains {graph_df.shape[0]} edges, {adata.n_obs} cells.") print(f"{graph_df.shape[0]/adata.n_obs} neighbors per cell on average.") if return_data: return adata
[docs]def scanpy_workflow( adata: AnnData, filter_cell: Optional[bool] = False, min_gene: Optional[int] = 200, min_cell: Optional[int] = 30, call_hvg: Optional[bool] = True, n_top_genes: Optional[Union[int, List]] = 2500, batch_key: Optional[str] = None, n_comps: Optional[int] = 50, viz: Optional[bool] = False, resolution: Optional[float] = 0.8, ) -> AnnData: r""" Scanpy workflow using Seurat HVG Parameters ---------- adata adata filter_cell whether to filter cells and genes min_gene min number of genes per cell min_cell min number of cells per gene call_hvg whether to call highly variable genes (only support seurat_v3 method) n_top_genes n top genes or gene list n_comps n PCA components viz whether to run visualize steps resolution resolution for leiden clustering (used when viz=True) Return ---------- anndata object """ if "counts" not in adata.layers.keys(): adata.layers["counts"] = adata.X.copy() if filter_cell: sc.pp.filter_cells(adata, min_genes=min_gene) sc.pp.filter_genes(adata, min_cells=min_cell) if call_hvg: if isinstance(n_top_genes, int): if adata.n_vars > n_top_genes: sc.pp.highly_variable_genes( adata, n_top_genes=n_top_genes, flavor="seurat_v3", batch_key=batch_key ) else: adata.var["highly_variable"] = True print("All genes are highly variable.") elif isinstance(n_top_genes, list): adata.var["highly_variable"] = False n_top_genes = list(set(adata.var.index).intersection(set(n_top_genes))) adata.var.loc[n_top_genes, "highly_variable"] = True else: print("Skip calling highly variable genes.") sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) sc.pp.scale(adata, max_value=10) if n_comps > 0: sc.tl.pca(adata, n_comps=n_comps, svd_solver="arpack") if viz: sc.pp.neighbors(adata, n_neighbors=15, n_pcs=n_comps) sc.tl.umap(adata) sc.tl.leiden(adata, resolution=resolution) return adata