r"""
Visualize multiple datasets and the matching connections between them
"""
import random
from typing import List, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scanpy as sc
from anndata import AnnData
from .color import godsnot_102, vega_10_scanpy, vega_20_scanpy, zeileis_28
def get_color(n: int = 1, cmap: str = "scanpy", seed: int = 0):
    r"""
    Get ``n`` colors.

    Parameters
    ----------
    n
        number of colors you want
    cmap
        color map; ``"scanpy"`` picks evenly spaced colors from the scanpy
        default palettes (up to 102 colors). Any other value, or ``n > 102``,
        falls back to random hex colors.
    seed
        random seed used by the random fallback, so its colors are reproducible

    Returns
    -------
    A palette slice of ``n`` colors. The random fallback returns a single
    ``"#RRGGBB"`` string (not a list) when ``n == 1``, and a list otherwise.
    ``n <= 0`` returns an empty list.
    """
    if n <= 0:
        # Guard against ZeroDivisionError / negative slice steps below.
        return []
    if cmap == "scanpy":
        for palette_size, palette in (
            (10, vega_10_scanpy),
            (20, vega_20_scanpy),
            (28, zeileis_28),
            (102, godsnot_102),
        ):
            if n <= palette_size:
                # Spread the picks over the palette rather than taking the first n.
                return palette[:: palette_size // n][:n]
    print("WARNING: Using random color")
    # A private generator yields the same colors as seeding the global RNG,
    # without resetting the caller's global ``random`` state.
    rng = random.Random(seed)

    def _random_hex() -> str:
        return "#" + "".join(rng.choice("0123456789ABCDEF") for _ in range(6))

    if n == 1:
        return _random_hex()
    return [_random_hex() for _ in range(n)]
class build_3D:
    r"""
    Build 3D pics/models from multi-datasets

    Parameters
    ----------
    adatas
        list of AnnData objects, in layer order
    mappings
        list of SLAT matching results; ``mappings[j]`` connects ``adatas[j]``
        and ``adatas[j + 1]`` (row 0 indexes cells of ``adatas[j + 1]``,
        row 1 indexes cells of ``adatas[j]``)
    spatial_key
        obsm key of spatial info
    anno_key
        obs key of cell annotation such as celltype
    subsample_size
        number of matches drawn per mapping (capped at the number of matches)
    scale_coordinate
        min-max scale the x/y coordinates of every slide to [0, 1]
    """

    def __init__(
        self,
        adatas: List[AnnData],
        mappings: List[np.ndarray],
        spatial_key: Optional[str] = "spatial",
        anno_key: Optional[str] = "annotation",
        subsample_size: Optional[int] = 200,
        scale_coordinate: Optional[bool] = True,
    ) -> None:
        assert len(mappings) == len(adatas) - 1
        self.mappings = mappings
        self.loc_list = []
        self.anno_list = []
        for adata in adatas:
            # Work on a float copy: integer pixel coordinates would otherwise
            # be truncated to 0/1 when the scaled values are written back.
            loc = np.array(adata.obsm[spatial_key], dtype=float)
            if scale_coordinate:
                for i in range(2):
                    low, high = np.min(loc[:, i]), np.max(loc[:, i])
                    span = high - low
                    # A constant axis has no extent to scale (0/0); map it to 0.
                    loc[:, i] = (loc[:, i] - low) / span if span else 0.0
            self.loc_list.append(loc)
            self.anno_list.append(adata.obs[anno_key])
        self.adatas = adatas
        self.anno_key = anno_key
        self.celltypes = set(pd.concat(self.anno_list))
        self.subsample_size = subsample_size

    def _layer_colors(self, ad: AnnData):
        """Return the annotation categories of ``ad`` and a category -> color dict."""
        # Derive categories without mutating the caller's ``ad.obs``.
        categories = ad.obs[self.anno_key].astype("category").cat.categories
        colors_key = f"{self.anno_key}_colors"
        if colors_key in ad.uns.keys():
            palette = ad.uns[colors_key]
        elif len(categories) > 28:
            palette = sc.pl.palettes.default_102
        else:
            palette = sc.pl.palettes.zeileis_28
        return categories, dict(zip(categories.tolist(), palette))

    def draw_3D(
        self,
        size: Optional[List[int]] = [10, 10],
        point_size: Optional[List[int]] = [0.5, 0.5],
        point_alpha: Optional[float] = 0.6,
        line_width: Optional[float] = 0.6,
        line_color: Optional[str] = "#4169E1",
        line_alpha: Optional[float] = 0.8,
        hide_axis: Optional[bool] = False,
        height: Optional[float] = 1.0,
        height_scale: Optional[float] = 1.0,
    ) -> plt.Axes:
        r"""
        Draw 3D picture of all layers and a subsample of their matches

        Parameters
        ----------
        size
            plt figure size (width, height)
        point_size
            point size of the first layer (index 0) and of every later layer (index 1)
        point_alpha
            point alpha of each layer
        line_width
            pair line width
        line_color
            pair line color
        line_alpha
            pair line alpha
        hide_axis
            if hide axis
        height
            height of one layer
        height_scale
            z-axis box aspect per mapping

        Returns
        -------
        The 3D matplotlib axes the picture was drawn on.
        """
        fig = plt.figure(figsize=(size[0], size[1]))
        ax = fig.add_subplot(111, projection="3d")
        ax.set_box_aspect([1, 1, height_scale * len(self.mappings)])
        for j, mapping in enumerate(self.mappings):
            print(f"Mapping {j}th layer ")
            # plot cells of layers j and j + 1
            layers = zip(
                self.loc_list[j : j + 2], self.anno_list[j : j + 2], self.adatas[j : j + 2]
            )
            for i, (layer, anno, ad) in enumerate(layers):
                # Layer j was already drawn as the upper layer of mapping j - 1.
                if i == 0 and j > 0:
                    continue
                categories, c_map = self._layer_colors(ad)
                for cell_type in categories:
                    cells = layer[(anno == cell_type).to_numpy(), :]
                    ax.scatter(
                        cells[:, 0],
                        cells[:, 1],
                        height * (j + i),
                        s=point_size[i],
                        c=c_map[cell_type],
                        alpha=point_alpha,
                    )
            # plot a random subsample of mapping lines
            n_lines = min(self.subsample_size, mapping.shape[1])
            sampled = mapping[:, np.random.choice(mapping.shape[1], n_lines, replace=False)]
            for k in range(sampled.shape[1]):
                query_index = sampled[0, k]  # cell in layer j + 1
                ref_index = sampled[1, k]  # cell in layer j
                coord = np.vstack(
                    (self.loc_list[j][ref_index, :], self.loc_list[j + 1][query_index, :])
                )
                ax.plot(
                    coord[:, 0],
                    coord[:, 1],
                    [height * j, height * (j + 1)],
                    color=line_color,
                    linestyle="dashed",
                    linewidth=line_width,
                    alpha=line_alpha,
                )
        if hide_axis:
            plt.axis("off")
        return ax
class match_3D_multi:
    r"""
    Plot the mapping result between 2 datasets

    Parameters
    ----------
    dataset_A
        pandas dataframe which contain ['index','x','y'], reference dataset
    dataset_B
        pandas dataframe which contain ['index','x','y'], target dataset
    matching
        matching results; for each column, row 0 is a ``dataset_B`` "index"
        value and row 1 a ``dataset_A`` "index" value
    meta
        dataframe colname of meta, such as celltype
    expr
        dataframe colname of gene expr
    subsample_size
        subsample size of matches (``<= 0`` keeps all)
    reliability
        per-pair match score (e.g. cosine similarity), aligned with the
        columns of ``matching``; it is subset/subsampled together with them
    scale_coordinate
        if scale coordinate via :math:`(data - min(data)) / (max(data) - min(data))`
    rotate
        how to rotate the slides (forces scale_coordinate), such as ['x','y'], means
        dataset0 rotate on x axes and dataset1 rotate on y axes
    exchange_xy
        exchange x and y on dataset_B
    subset
        index of query cells (columns of ``matching``) to be plotted

    Note
    ----------
    dataset_A and dataset_B can in different length
    """

    def __init__(
        self,
        dataset_A: pd.DataFrame,
        dataset_B: pd.DataFrame,
        matching: np.ndarray,
        meta: Optional[str] = None,
        expr: Optional[str] = None,
        subsample_size: Optional[int] = 300,
        reliability: Optional[np.ndarray] = None,
        scale_coordinate: Optional[bool] = True,
        rotate: Optional[List[str]] = None,
        exchange_xy: Optional[bool] = False,
        subset: Optional[List[int]] = None,
    ) -> None:
        self.dataset_A = dataset_A.copy()
        self.dataset_B = dataset_B.copy()
        self.meta = meta
        self.subset = subset  # index of query cells to be plotted
        scale_coordinate = True if rotate is not None else scale_coordinate
        assert all(item in dataset_A.columns.values for item in ["index", "x", "y"])
        assert all(item in dataset_B.columns.values for item in ["index", "x", "y"])
        if meta:
            set1 = list(set(self.dataset_A[meta]))
            set2 = list(set(self.dataset_B[meta]))
            self.celltypes = set1 + [x for x in set2 if x not in set1]
            self.celltypes.sort()  # make sure celltypes are in the same order
            overlap = set(set2).intersection(set1)
            print(
                f"dataset1: {len(set1)} cell types; dataset2: {len(set2)} cell types; \n"
                f"Total :{len(self.celltypes)} celltypes; Overlap: {len(overlap)} cell types \n"
                f"Not overlap :[{[y for y in (set1 + set2) if y not in overlap]}]"
            )
        self.expr = expr if expr else False
        if scale_coordinate:
            for i, dataset in enumerate([self.dataset_A, self.dataset_B]):
                for axis in ["x", "y"]:
                    dataset[axis] = (dataset[axis] - np.min(dataset[axis])) / (
                        np.max(dataset[axis]) - np.min(dataset[axis])
                    )
                    if rotate is not None and axis in rotate[i]:
                        dataset[axis] = 1 - dataset[axis]
        if exchange_xy:
            # list-key assignment of a DataFrame is positional, so this swaps x and y
            self.dataset_B[["x", "y"]] = self.dataset_B[["y", "x"]]
        # Track which columns of the ORIGINAL matching are kept, so the
        # reliability scores can be selected with exactly the same columns.
        columns = np.arange(matching.shape[1])
        if subset is not None:
            columns = columns[subset]
        n_available = len(columns)
        if n_available > subsample_size and subsample_size > 0:
            columns = np.random.choice(columns, subsample_size, replace=False)
        else:
            subsample_size = n_available
        print(f"Subsampled {subsample_size} pairs from {n_available}")
        self.matching = matching[:, columns]
        self.conf = None if reliability is None else np.asarray(reliability)[columns]
        self.datasets = [self.dataset_A, self.dataset_B]

    def draw_3D(
        self,
        size: Optional[List[int]] = [10, 10],
        conf_cutoff: Optional[float] = 0,
        point_size: Optional[List[float]] = [0.1, 0.1],
        line_width: Optional[float] = 0.3,
        line_color: Optional[str] = "grey",
        line_alpha: Optional[float] = 0.7,
        hide_axis: Optional[bool] = False,
        show_error: Optional[bool] = True,
        show_celltype: Optional[bool] = False,
        cmap: Optional[str] = "Reds",
        save: Optional[str] = None,
    ) -> None:
        r"""
        Draw 3D picture of two datasets

        Parameters
        ----------
        size
            plt figure size
        conf_cutoff
            confidence cutoff of mapping to be plotted
        point_size
            point size of every dataset
        line_width
            pair line width
        line_color
            pair line color
        line_alpha
            pair line alpha
        hide_axis
            if hide axis
        show_error
            if show error celltype mapping with different color (needs ``meta``)
        show_celltype
            if color correct pairs by their celltype color
        cmap
            color map when vis expr
        save
            save file path
        """
        self.conf_cutoff = conf_cutoff
        show_error = show_error if self.meta else False
        fig = plt.figure(figsize=(size[0], size[1]))
        ax = fig.add_subplot(111, projection="3d")
        # Celltype -> color; stays empty without meta. Also used by draw_lines.
        c_map = {}
        if self.meta:
            color = get_color(len(self.celltypes))
            if isinstance(color, str):  # get_color may return a bare string for n == 1
                color = [color]
            c_map = {celltype: color[i] for i, celltype in enumerate(self.celltypes)}
            for i, dataset in enumerate(self.datasets):
                norm = None
                if self.expr:
                    values = dataset[self.expr].to_numpy()
                    norm = plt.Normalize(values.min(), values.max())
                for cell_type in self.celltypes:
                    cells = dataset[dataset[self.meta] == cell_type]
                    if self.expr:
                        ax.scatter(
                            cells["x"],
                            cells["y"],
                            i,
                            s=point_size[i],
                            c=cells[self.expr],
                            cmap=cmap,
                            norm=norm,
                        )
                    else:
                        ax.scatter(cells["x"], cells["y"], i, s=point_size[i], c=c_map[cell_type])
        else:
            # plot points without meta
            for i, dataset in enumerate(self.datasets):
                ax.scatter(dataset["x"], dataset["y"], i, s=point_size[i])
        self.c_map = c_map
        self.draw_lines(ax, show_error, show_celltype, line_color, line_width, line_alpha)
        if hide_axis:
            plt.axis("off")
        if save is not None:
            plt.savefig(save)
        plt.show()

    def draw_lines(
        self, ax, show_error, show_celltype, line_color, line_width=0.3, line_alpha=0.7
    ) -> None:
        r"""
        Draw lines between paired cells in two datasets

        Parameters
        ----------
        ax
            3D matplotlib axes to draw on
        show_error
            color lines by whether both cells share the same ``meta`` label
        show_celltype
            color correct pairs with their celltype color and wrong pairs dark grey
            (takes precedence over ``show_error``)
        line_color
            line color when neither ``show_error`` nor ``show_celltype`` applies
        line_width
            pair line width
        line_alpha
            pair line alpha
        """
        conf_cutoff = getattr(self, "conf_cutoff", 0)
        # Labels are compared as strings, so key the celltype colors by str too.
        celltype_colors = {str(k): v for k, v in getattr(self, "c_map", {}).items()}
        # One row per cell id (first occurrence wins) for constant-time lookups.
        ref_cells = self.dataset_A.drop_duplicates("index").set_index("index")
        query_cells = self.dataset_B.drop_duplicates("index").set_index("index")
        for i in range(self.matching.shape[1]):
            if self.conf is not None and self.conf[i] < conf_cutoff:
                continue
            pair = self.matching[:, i]
            ref_row = ref_cells.loc[pair[1]]
            query_row = query_cells.loc[pair[0]]
            color = line_color
            if self.meta is not None:
                celltype1 = str(ref_row[self.meta])
                celltype2 = str(query_row[self.meta])
                same_type = celltype1 == celltype2
                if show_error:
                    color = "#ade8f4" if same_type else "#ffafcc"  # blue / red
                if show_celltype:
                    # "#696969": celltype error match color
                    color = celltype_colors[celltype1] if same_type else "#696969"
            point0 = np.array([ref_row["x"], ref_row["y"], 0], dtype=float)
            point1 = np.array([query_row["x"], query_row["y"], 1], dtype=float)
            coord = np.vstack((point0, point1))
            ax.plot(
                coord[:, 0],
                coord[:, 1],
                coord[:, 2],
                color=color,
                linestyle="dashed",
                linewidth=line_width,
                alpha=line_alpha,
            )
class match_3D_multi_error(match_3D_multi):
    r"""
    Highlight the error mapping between datasets, child of class:`match_3D_multi()`

    Parameters
    ----------
    dataset_A
        pandas dataframe which contain ['index','x','y']
    dataset_B
        pandas dataframe which contain ['index','x','y']
    matching
        matching results
    mode
        which cell pairs to highlight, one of "high_true", "low_true",
        "high_false", "low_false"
    highlight_color
        color to highlight the line
    meta
        dataframe colname of meta, such as celltype (the "celltype" column is
        used when it is None)
    expr
        dataframe colname of gene expr
    subsample_size
        subsample size of matches
    reliability
        if the match is reliable
    scale_coordinate
        if scale the coordinate via `(data - np.min(data)) / (np.max(data) - np.min(data))`
    rotate
        how to rotate the slides (force scale_coordinate)
    exchange_xy
        exchange x and y on dataset_B
    subset
        index of query cells to be plotted

    Note
    ----------
    dataset_A and dataset_B can in different length
    """

    def __init__(
        self,
        dataset_A: pd.DataFrame,
        dataset_B: pd.DataFrame,
        matching: np.ndarray,
        mode: Optional[str] = "high_true",
        highlight_color: Optional[str] = "red",
        meta: Optional[str] = None,
        expr: Optional[str] = None,
        subsample_size: Optional[int] = 300,
        reliability: Optional[np.ndarray] = None,
        scale_coordinate: Optional[bool] = False,
        rotate: Optional[List[str]] = None,
        exchange_xy: Optional[bool] = False,
        subset: Optional[Union[np.ndarray, List[int]]] = None,
    ) -> None:
        super().__init__(
            dataset_A,
            dataset_B,
            matching,
            meta,
            expr,
            subsample_size,
            reliability,
            scale_coordinate,
            rotate,
            exchange_xy,
            subset,
        )
        assert mode in ["high_true", "low_true", "high_false", "low_false"]
        self.mode = mode
        self.highlight_color = highlight_color

    def draw_lines(
        self,
        ax,
        show_error,
        show_celltype,
        default_color,
        line_width=0.3,
        line_alpha=0.7,
    ) -> None:
        r"""
        Draw the cell pairs selected by ``self.mode`` in ``self.highlight_color``.

        ``show_error``, ``show_celltype`` and ``default_color`` are accepted for
        signature compatibility with the parent's ``draw_3D`` call and are unused.
        """
        # Previously hard-coded to "celltype"; honor the configured meta column.
        label_col = self.meta if self.meta is not None else "celltype"
        for i in range(self.matching.shape[1]):
            pair = self.matching[:, i]
            ref = self.dataset_A[self.dataset_A["index"] == pair[1]]
            query = self.dataset_B[self.dataset_B["index"] == pair[0]]
            # Compare scalar labels; truth-testing a numpy comparison array is fragile.
            is_true_match = (
                query[label_col].astype(str).values[0] == ref[label_col].astype(str).values[0]
            )
            # NOTE(review): mismatched pairs are not filtered out in "*_true" modes,
            # and "low_*" keeps only pairs with a truthy reliability — kept as-is;
            # confirm this is the intended mode semantics.
            if is_true_match and "false" in self.mode:
                continue
            if self.conf is not None and "low" in self.mode and not self.conf[i]:
                continue
            point0 = np.append(ref[["x", "y"]].to_numpy()[0], 0)
            point1 = np.append(query[["x", "y"]].to_numpy()[0], 1)
            coord = np.vstack((point0, point1))
            ax.scatter(point0[0], point0[1], point0[2], color="red", alpha=1, s=0.3)
            ax.scatter(point1[0], point1[1], point1[2], color="red", alpha=1, s=0.3)
            ax.plot(
                coord[:, 0],
                coord[:, 1],
                coord[:, 2],
                color=self.highlight_color,
                linestyle="dashed",
                linewidth=line_width,
                alpha=line_alpha,
            )
class match_3D_celltype(match_3D_multi):
    r"""
    Highlight the celltype mapping, child of class:`match_3D_multi()`

    Parameters
    ----------
    dataset_A
        pandas dataframe which contain ['index','x','y']
    dataset_B
        pandas dataframe which contain ['index','x','y']
    matching
        matching results
    highlight_celltype
        celltypes to highlight in two datasets ([dataset_A types, dataset_B types])
    highlight_line
        color (or list of colors, one per highlighted celltype) to highlight the line
    highlight_cell
        color to highlight the cell
    meta
        dataframe col name of meta, such as celltype
    expr
        dataframe col name of gene expr
    subsample_size
        subsample size of matches
    reliability
        if the match is reliable
    scale_coordinate
        if scale the coordinate via `(data - np.min(data)) / (np.max(data) - np.min(data))`
    rotate
        how to rotate the slides (force scale_coordinate)
    exchange_xy
        exchange x and y on dataset_B
    subset
        index of query cells to be plotted

    Note
    ----------
    dataset_A and dataset_B can in different length
    """

    def __init__(
        self,
        dataset_A: pd.DataFrame,
        dataset_B: pd.DataFrame,
        matching: np.ndarray,
        highlight_celltype: Optional[List[List[str]]] = [[], []],
        highlight_line: Optional[Union[List[str], str]] = "red",
        highlight_cell: Optional[str] = None,
        meta: Optional[str] = None,
        expr: Optional[str] = None,
        subsample_size: Optional[int] = 300,
        reliability: Optional[np.ndarray] = None,
        scale_coordinate: Optional[bool] = False,
        rotate: Optional[List[str]] = None,
        exchange_xy: Optional[bool] = False,
        subset: Optional[Union[np.ndarray, List[int]]] = None,
    ) -> None:
        super().__init__(
            dataset_A,
            dataset_B,
            matching,
            meta,
            expr,
            subsample_size,
            reliability,
            scale_coordinate,
            rotate,
            exchange_xy,
            subset,
        )
        assert set(highlight_celltype[0]).issubset(set(self.celltypes))
        assert set(highlight_celltype[1]).issubset(set(self.celltypes))
        self.highlight_celltype = highlight_celltype
        self.highlight_line = highlight_line
        self.highlight_cell = highlight_cell

    def draw_lines(
        self,
        ax,
        show_error,
        show_celltype=False,
        line_color="grey",
        line_width: float = 0.3,
        line_alpha: float = 0.7,
    ) -> None:
        r"""
        Draw only pairs whose celltypes are both in ``highlight_celltype``.

        The signature mirrors the parent's ``draw_lines`` so that the inherited
        ``draw_3D`` (which passes six positional arguments) works.

        Parameters
        ----------
        ax
            3D matplotlib axes to draw on
        show_error
            if True, color lines per highlighted celltype; otherwise use a single color
        show_celltype
            unused, kept for signature compatibility
        line_color
            fallback single color when ``highlight_line`` is a list
        line_width
            pair line width
        line_alpha
            pair line alpha
        """
        # Labels are read as strings below, so compare against strings too.
        highlight_ref = [str(c) for c in self.highlight_celltype[0]]
        highlight_query = [str(c) for c in self.highlight_celltype[1]]
        if not highlight_ref or not highlight_query:
            return  # no pair can pass the celltype filter
        color_by_ref = len(highlight_ref) >= len(highlight_query)
        color_index = highlight_ref if color_by_ref else highlight_query
        if isinstance(self.highlight_line, list) and len(self.highlight_line) >= len(color_index):
            cmap = self.highlight_line
        else:
            cmap = get_color(len(color_index))
        for i in range(self.matching.shape[1]):
            pair = self.matching[:, i]
            ref = self.dataset_A[self.dataset_A["index"] == pair[1]]
            query = self.dataset_B[self.dataset_B["index"] == pair[0]]
            ref_type = ref[self.meta].astype(str).values[0]
            query_type = query[self.meta].astype(str).values[0]
            if ref_type not in highlight_ref or query_type not in highlight_query:
                continue
            point0 = np.append(ref[["x", "y"]].to_numpy()[0], 0)
            point1 = np.append(query[["x", "y"]].to_numpy()[0], 1)
            coord = np.vstack((point0, point1))
            if self.highlight_cell:
                ax.scatter(
                    point0[0], point0[1], point0[2], color=self.highlight_cell, alpha=1, s=1
                )
                ax.scatter(
                    point1[0], point1[1], point1[2], color=self.highlight_cell, alpha=1, s=1
                )
            if not show_error:
                # A single color is needed; a list of highlight colors cannot be used.
                color = (
                    self.highlight_line if isinstance(self.highlight_line, str) else line_color
                )
            elif isinstance(cmap, list):
                color = cmap[color_index.index(ref_type if color_by_ref else query_type)]
            else:
                color = cmap
            ax.plot(
                coord[:, 0],
                coord[:, 1],
                coord[:, 2],
                color=color,
                linestyle="dashed",
                linewidth=line_width,
                alpha=line_alpha,
            )
def Sankey(
    matching_table: pd.DataFrame,
    filter_num: Optional[int] = 50,
    color: Optional[List[str]] = "red",
    title: Optional[str] = "",
    prefix: Optional[List[str]] = ["E11.5", "E12.5"],
    layout: Optional[List[int]] = [1300, 900],
    font_size: Optional[float] = 15,
    font_color: Optional[str] = "Black",
    save_name: Optional[str] = None,
    format: Optional[str] = "png",
    width: Optional[int] = 1200,
    height: Optional[int] = 1000,
    return_fig: Optional[bool] = False,
) -> Optional[go.Figure]:
    r"""
    Sankey plot of celltype

    Parameters
    ----------
    matching_table
        matching count table; columns are reference celltypes (labelled with
        ``prefix[0]``) and the index holds query celltypes (labelled with ``prefix[1]``)
    filter_num
        only links with more than this many matches are drawn
    color
        color of node
    title
        plot title
    prefix
        prefix to distinguish datasets
    layout
        layout size of picture
    font_size
        font size in plot
    font_color
        font color in plot
    save_name
        save file name (None for not save)
    format
        save picture format (see https://plotly.com/python/static-image-export/ for more details)
    width
        save picture width
    height
        save picture height
    return_fig
        if return plotly figure

    Returns
    -------
    The plotly figure when ``return_fig`` is True, otherwise None.
    """
    # f-strings also accept non-string celltype labels (e.g. integer clusters).
    label_ref = [f"{a}_{prefix[0]}" for a in matching_table.columns]
    label_query = [f"{a}_{prefix[1]}" for a in matching_table.index]
    label_all = label_query + label_ref
    label2index = {label: k for k, label in enumerate(label_all)}
    source, target, value = [], [], []
    for i, query in enumerate(label_query):
        for j, ref in enumerate(label_ref):
            count = int(matching_table.iloc[i, j])
            if count > filter_num:
                target.append(label2index[query])
                source.append(label2index[ref])
                value.append(count)
    fig = go.Figure(
        data=[
            go.Sankey(
                node=dict(
                    pad=50,
                    thickness=50,
                    line=dict(color="green", width=0.5),
                    label=label_all,
                    color=color,
                ),
                link=dict(
                    source=source,  # indices correspond to labels, eg A1, A2, A1, B1, ...
                    target=target,
                    value=value,
                ),
            )
        ],
        layout=go.Layout(autosize=False, width=layout[0], height=layout[1]),
    )
    fig.update_layout(title_text=title, font_size=font_size, font_color=font_color)
    fig.show()
    if save_name is not None:
        fig.write_image(save_name + f".{format}", width=width, height=height)
    if return_fig:
        return fig
def multi_Sankey(
    matching_tables: List[pd.DataFrame],
    color: Optional[List[str]] = "random",
    title: Optional[str] = "Sankey plot",
    layout: Optional[List[int]] = [1300, 900],
    day: Optional[float] = 0,
    save_name: Optional[str] = None,
    format: Optional[str] = "svg",
    filter_num: Optional[int] = 10,
) -> None:
    r"""
    Sankey plot of celltype in multi datasets

    Parameters
    ----------
    matching_tables
        list of matching tables; table ``i`` counts matches between the celltypes of
        dataset ``i`` (columns, reference) and dataset ``i + 1`` (index, query)
    color
        how to color the nodes: 'random' for one random color per node, 'celltype'
        for one color per celltype (shared across datasets), or any node color
        value accepted by plotly (a color or a list of colors)
    title
        plot title
    layout
        layout size of picture
    day
        start day of dataset for temporal order; dataset ``i`` is labelled ``day + i``
    save_name
        save file name (None for not save)
    format
        saved picture format (see https://plotly.com/python/static-image-export/ for more details)
    filter_num
        only links with more than this many matches are drawn
    """
    prefixes = [day + i for i in range(len(matching_tables) + 1)]
    # Node label -> celltype, in first-seen order. An insertion-ordered dict
    # (instead of a set) keeps node order, indices and colors consistent.
    nodes = {}
    for i, matching_table in enumerate(matching_tables):
        for celltype in matching_table.columns:
            nodes.setdefault(f"{celltype}_{prefixes[i]}", celltype)
        for celltype in matching_table.index:
            nodes.setdefault(f"{celltype}_{prefixes[i + 1]}", celltype)
    labels = list(nodes)
    label2index = {label: k for k, label in enumerate(labels)}

    source, target, value = [], [], []
    for i, matching_table in enumerate(matching_tables):
        for r, query in enumerate(matching_table.index):
            for c, ref in enumerate(matching_table.columns):
                count = int(matching_table.iloc[r, c])
                if count > filter_num:
                    # Reuse the exact prefixes used for the labels above.
                    target.append(label2index[f"{query}_{prefixes[i + 1]}"])
                    source.append(label2index[f"{ref}_{prefixes[i]}"])
                    value.append(count)

    def _as_color_list(colors):
        # get_color may return a bare string (n == 1) or None (n == 0).
        if colors is None:
            return []
        return [colors] if isinstance(colors, str) else list(colors)

    if isinstance(color, str) and color == "random":
        color = _as_color_list(get_color(len(labels), cmap="random")) if labels else []
    elif isinstance(color, str) and color == "celltype":
        celltypes = list(dict.fromkeys(nodes.values()))
        palette = _as_color_list(get_color(len(celltypes))) if celltypes else []
        celltype_color = dict(zip(celltypes, palette))
        color = [celltype_color[celltype] for celltype in nodes.values()]

    fig = go.Figure(
        data=[
            go.Sankey(
                node=dict(
                    pad=50,
                    thickness=50,
                    line=dict(color="green", width=0.5),
                    label=labels,
                    color=color,
                ),
                link=dict(
                    source=source,  # indices correspond to labels, eg A1, A2, A1, B1, ...
                    target=target,
                    value=value,
                ),
            )
        ],
        layout=go.Layout(autosize=False, width=layout[0], height=layout[1]),
    )
    fig.update_layout(title_text=title, font_size=10)
    if save_name is not None:
        fig.write_image(save_name + f".{format}", width=layout[0], height=layout[1])
    fig.show()
def matching_2d(
    matching: np.ndarray,
    ref: AnnData,
    src: AnnData,
    biology_meta: str,
    topology_meta: str,
    spot_size: Optional[int] = 5,
    title: Optional[str] = "2D matching",
    save: Optional[str] = None,
) -> None:
    r"""
    Visualize the matching result in 2D space

    Each ``src`` cell is colored by whether its matched ``ref`` cell
    (``matching[1, k]`` for the k-th ``src`` row) has the same celltype and
    the same region: red = both wrong, purple = region only, yellow =
    celltype only, green = both right.

    Parameters
    ----------
    matching
        matching result; row 1 holds the ``ref`` row position matched to each
        ``src`` row (assumed to be in ``src.obs`` order — TODO confirm with callers)
    ref
        reference dataset
    src
        target dataset
    biology_meta
        celltype meta colname of adata.obs
    topology_meta
        region meta colname of adata.obs
    spot_size
        size of spot for visualization
    title
        plot title
    save
        save file name (None for not save)
    """
    # Fixed category order, so every outcome keeps its palette color even when
    # some outcome does not occur (otherwise the colors shift).
    categories = [
        "celltype_false_region_false",
        "celltype_false_region_true",
        "celltype_true_region_false",
        "celltype_true_region_true",
    ]
    matched_ref = ref.obs.iloc[matching[1, :], :]
    cell_type_match = src.obs[biology_meta].to_numpy() == matched_ref[biology_meta].to_numpy()
    region_match = src.obs[topology_meta].to_numpy() == matched_ref[topology_meta].to_numpy()
    vis = np.select(
        [
            cell_type_match & region_match,
            ~cell_type_match & region_match,
            cell_type_match & ~region_match,
        ],
        [categories[3], categories[1], categories[2]],
        default=categories[0],
    )
    src.obs["vis"] = pd.Categorical(vis, categories=categories)
    try:
        sc.pl.spatial(
            src,
            color="vis",
            spot_size=spot_size,
            title=title,
            palette=["red", "purple", "yellow", "green"],
            save=save,
        )
    finally:
        # Remove the temporary column even if plotting fails.
        del src.obs["vis"]