Source code for scSLAT.utils

r"""
Miscellaneous utilities
"""
import random
from subprocess import run
from typing import Optional, Union

import numpy as np
import pynvml
import torch
from anndata import AnnData


def norm_to_raw(
    adata: AnnData,
    library_size: Optional[Union[str, np.ndarray]] = "total_counts",
    check_size: Optional[int] = 100,
) -> AnnData:
    r"""
    Convert normalized adata.X to raw counts

    The log1p transform is undone with ``expm1``, every cell (row) is rescaled
    so that it sums to its raw library size, and the result is rounded to
    whole counts. ``adata.X`` is overwritten in place with a CSR matrix and
    the same ``adata`` object is returned.

    Parameters
    ----------
    adata
        adata to be converted; ``adata.X`` may be sparse or dense
    library_size
        raw library size of every cell, can be a key of `adata.obs` or an
        array-like with one value per cell (a scalar is broadcast to all cells)
    check_size
        check the head `[0:check_size]` rows and columns to judge if adata
        is normalized

    Returns
    ----------
    The input ``adata`` with ``adata.X`` replaced by the recovered counts

    Raises
    ----------
    ValueError
        if the checked head of ``adata.X`` holds only whole numbers (looks
        like raw counts already) or if `library_size` can not be matched to
        the cells

    Note
    ----------
    Adata must follow scanpy official norm step
    """
    from scipy import sparse

    # A CSR view of X lets dense and sparse inputs share one code path
    log_norm = sparse.csr_matrix(adata.X)

    # Normalized + log1p data is fractional; if every stored value in the head
    # is a whole number the matrix is most likely raw counts already.
    stored = log_norm[0:check_size, 0:check_size].data
    if stored.size and np.all(np.mod(stored, 1) == 0):
        raise ValueError(
            "`adata.X` looks like raw counts already (only whole numbers found)"
        )

    # expm1 returns a new matrix, so adata.X is untouched until the final assignment
    normalized = log_norm.expm1()
    row_totals = np.asarray(normalized.sum(axis=1)).ravel().round()

    if library_size is None:
        raise ValueError("Invalid `library_size`")
    library = adata.obs[library_size] if isinstance(library_size, str) else library_size
    try:
        library = np.broadcast_to(
            np.asarray(library, dtype=float).ravel(), row_totals.shape
        )
    except (TypeError, ValueError):
        raise ValueError("Invalid `library_size`") from None

    # Cells without any signal have a zero total; give them a zero factor
    # instead of dividing by zero.
    scale_factor = np.divide(
        library,
        row_totals,
        out=np.zeros(row_totals.shape, dtype=float),
        where=row_totals != 0,
    )

    # Row-wise scaling of a CSR matrix: repeat each row factor once per stored entry
    per_entry = np.repeat(scale_factor, np.diff(normalized.indptr))
    raw_count = sparse.csr_matrix(
        (np.round(normalized.data * per_entry), normalized.indices, normalized.indptr),
        shape=normalized.shape,
    )
    raw_count.eliminate_zeros()  # rounding may leave explicitly stored zeros

    adata.X = raw_count
    return adata
def get_free_gpu() -> int:
    r"""
    Get index of GPU with least memory usage

    Every device reported by ``torch.cuda.device_count()`` is queried through
    NVML and the index of the one with the most free memory is returned.
    Index ``0`` is returned when CUDA is not available.

    Returns
    ----------
    Index of the GPU with the most free memory (first one wins on ties)

    Ref
    ----------
    https://stackoverflow.com/questions/58216000/get-total-amount-of-free-gpu-memory-and-available-using-pytorch
    """
    if not torch.cuda.is_available():
        return 0

    # NOTE(review): indices are passed to NVML unchanged; presumably NVML and
    # torch number the devices identically here -- confirm when
    # CUDA_VISIBLE_DEVICES is used to hide or reorder GPUs.
    pynvml.nvmlInit()
    try:
        best_index, most_free = 0, 0
        for i in range(torch.cuda.device_count()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            free = pynvml.nvmlDeviceGetMemoryInfo(handle).free
            if free > most_free:
                best_index, most_free = i, free
        return best_index
    finally:
        # Balance nvmlInit() so the NVML session is released even on errors
        pynvml.nvmlShutdown()
def global_seed(seed: int) -> None:
    r"""
    Seed the Python, NumPy and PyTorch random number generators

    A seed outside ``[0, 2**32 - 1]`` is replaced by ``0``. The seed that was
    actually applied is printed.

    Parameters
    ----------
    seed
        seed shared by all generators
    """
    upper_bound = 2**32 - 1
    if seed < 0 or seed > upper_bound:
        seed = 0

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Seed the generators of every CUDA device as well when GPUs are present
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    print(f"Global seed set to {seed}.")
def _parse_major_minor(version: str) -> tuple:
    r"""
    Turn a dotted version string such as ``2.2.1+cu121`` into ``(major, minor)``

    Only the first two numeric fields are used; a local build suffix after
    ``+`` is ignored. Raises ``ValueError`` when the string can not be parsed.
    """
    fields = version.split("+")[0].strip().split(".")
    try:
        return int(fields[0]), int(fields[1])
    except (IndexError, ValueError):
        raise ValueError(f"Can not parse version string {version!r}") from None


def _parse_cuda_release(cuda_version: str) -> tuple:
    r"""
    Parse a CUDA version given as ``12.1`` or as a wheel tag such as ``cu121``

    For the tag form the last digit is the minor version and the digits
    before it are the major version.
    """
    text = cuda_version.strip().lower()
    if not text.startswith("cu"):
        return _parse_major_minor(text)
    digits = text[2:]
    if len(digits) < 2 or not digits.isdigit():
        raise ValueError(f"Can not parse cuda version {cuda_version!r}")
    return int(digits[:-1]), int(digits[-1])


def install_pyg_dep(torch_version: str = None, cuda_version: str = None) -> None:
    r"""
    Automatically install PyG dependencies

    The torch and CUDA versions are mapped to the matching PyG wheel index
    and ``pip`` is run for the current interpreter.

    Parameters
    ----------
    torch_version
        torch version, e.g. 2.2.1; defaults to the installed torch version
    cuda_version
        cuda version, e.g. 12.1, a wheel tag such as ``cu121`` or ``cpu``;
        defaults to the CUDA version torch was built with (``cpu`` for
        CPU-only builds)

    Raises
    ----------
    ValueError
        if a version can not be parsed or the torch / cuda combination is not
        supported
    """
    import sys  # local import: only needed to locate the running interpreter

    if torch_version is None:
        torch_version = torch.__version__
    torch_version = torch_version.split("+")[0]

    # Compare versions numerically; as plain strings "2.10" sorts before "2.2"
    torch_release = _parse_major_minor(torch_version)
    if torch_release < (2, 0):
        raise ValueError(f"PyG only support torch>=2.0, but get {torch_version}")
    if torch_release > (2, 3):
        raise ValueError(f"Automatic install only support torch<=2.3, but get {torch_version}")
    # Wheel indexes are keyed by the ".0" patch release of every minor version
    torch_version = f"{torch_release[0]}.{torch_release[1]}.0"

    if cuda_version is None:
        cuda_version = torch.version.cuda  # None for CPU-only torch builds

    if cuda_version is None or cuda_version.strip().lower() == "cpu":
        cuda_version = "cpu"
    elif "cu" in cuda_version and not torch.cuda.is_available():
        print("CUDA is not available, try install CPU version, but may raise error.")
        cuda_version = "cpu"
    else:
        cuda_release = _parse_cuda_release(cuda_version)
        if cuda_release >= (12, 1):
            cuda_version = "cu121"
        elif cuda_release >= (11, 8):
            cuda_version = "cu118"
        elif cuda_release >= (11, 7):
            cuda_version = "cu117"
        else:
            raise ValueError(f"PyG only support cuda>=11.7, but get {cuda_version}")

    url = "https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html"
    if torch_version in ["2.3.0", "2.2.0", "2.1.0"] and cuda_version == "cu117":
        raise ValueError(
            f"PyG not support torch-{torch_version} with cuda-11.7, please check {url}"
        )
    if torch_version == "2.0.0" and cuda_version == "cu121":
        raise ValueError(f"PyG not support torch-2.0.* with cuda-12.1, please check {url}")

    print(f"Installing PyG dependencies for torch-{torch_version} and cuda-{cuda_version}")
    # Argument list without a shell, and pip of the running interpreter, so the
    # packages land in the environment that imported this module.
    cmd = [
        sys.executable,
        "-m",
        "pip",
        "--no-cache-dir",
        "install",
        "pyg_lib",
        "torch_scatter",
        "torch_sparse",
        "torch_cluster",
        "torch_spline_conv",
        "-f",
        f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html",
    ]
    result = run(cmd)
    if result.returncode != 0:
        print(f"pip exited with code {result.returncode}, PyG dependencies may not be installed.")