r"""
Training functions with different strategy
"""
from math import ceil
from typing import List, Optional
import torch
import torch.nn.functional as F
[docs]def train_GAN(
wdiscriminator: torch.nn.Module,
optimizer_d: torch.optim.Optimizer,
embds: List[torch.Tensor],
batch_d_per_iter: Optional[int] = 5,
anchor_scale: Optional[float] = 0.8,
) -> torch.Tensor:
r"""
GAN training strategy
Parameters
----------
wdiscriminator
WGAN
optimizer_d
WGAN parameters
embds
list of LGCN embd
batch_d_per_iter
WGAN train iter numbers
anchor_scale
ratio of anchor cells
"""
embd0, embd1 = embds
wdiscriminator.train()
anchor_size = ceil(embd1.size(0) * anchor_scale)
for j in range(batch_d_per_iter):
w0 = wdiscriminator(embd0)
w1 = wdiscriminator(embd1)
anchor1 = w1.view(-1).argsort(descending=True)[:anchor_size]
anchor0 = w0.view(-1).argsort(descending=False)[:anchor_size]
embd0_anchor = embd0[anchor0, :].clone().detach()
embd1_anchor = embd1[anchor1, :].clone().detach()
optimizer_d.zero_grad()
loss = -torch.mean(wdiscriminator(embd0_anchor)) + torch.mean(wdiscriminator(embd1_anchor))
loss.backward()
optimizer_d.step()
for p in wdiscriminator.parameters():
p.data.clamp_(-0.1, 0.1)
w0 = wdiscriminator(embd0)
w1 = wdiscriminator(embd1)
anchor1 = w1.view(-1).argsort(descending=True)[:anchor_size]
anchor0 = w0.view(-1).argsort(descending=False)[:anchor_size]
embd0_anchor = embd0[anchor0, :]
embd1_anchor = embd1[anchor1, :]
loss = -torch.mean(wdiscriminator(embd1_anchor))
return loss
[docs]def feature_reconstruct_loss(
embd: torch.Tensor, x: torch.Tensor, recon_model: torch.nn.Module
) -> torch.Tensor:
r"""
Reconstruction loss (MSE)
Parameters
----------
embd
embd of a cell
x
input
recon_model
reconstruction model
"""
recon_x = recon_model(embd)
return torch.norm(recon_x - x, dim=1, p=2).mean()
[docs]def train_reconstruct(
recon_models: torch.nn.Module,
optimizer_recons,
embds: List[torch.Tensor],
features: List[torch.Tensor],
batch_r_per_iter: Optional[int] = 10,
) -> torch.Tensor:
r"""
Data reconstruction network training strategy
Parameters
----------
recon_models
list of reconstruction model
optimizer_recons
list of reconstruction optimizer
embds
list of LGCN embd
features
list of rae node features
batch_d_per_iter
WGAN train iter numbers
"""
recon_model0, recon_model1 = recon_models
optimizer_recon0, optimizer_recon1 = optimizer_recons
embd0, embd1 = embds
recon_model0.train()
recon_model1.train()
embd0_copy = embd0.clone().detach()
embd1_copy = embd1.clone().detach()
for t in range(batch_r_per_iter):
optimizer_recon0.zero_grad()
loss = feature_reconstruct_loss(embd0_copy, features[0], recon_model0)
loss.backward()
optimizer_recon0.step()
for t in range(batch_r_per_iter):
optimizer_recon1.zero_grad()
loss = feature_reconstruct_loss(embd1_copy, features[1], recon_model1)
loss.backward()
optimizer_recon1.step()
loss = 0.5 * feature_reconstruct_loss(
embd0, features[0], recon_model0
) + 0.5 * feature_reconstruct_loss(embd1, features[1], recon_model1)
return loss
[docs]def check_align(
embds: List[torch.Tensor],
ground_truth: torch.Tensor,
k: Optional[int] = [5, 10],
mode: Optional[str] = "cosine",
) -> List[float]:
r"""
Check embedding correspondence in given distance (default cosine similarity) under ground truth
Parameters
-----------
embds
List of graph features, each element is (node_num, feature_dim)
ground_truth
mapping ground_truth (2, node_num)
k
list of top k (only support 2 elements yet)
mode
distance quota
"""
embd0, embd1 = embds
assert k[1] > k[0]
g_map = {}
for i in range(ground_truth.size(1)):
g_map[ground_truth[1, i].item()] = ground_truth[0, i].item()
g_list = list(g_map.keys())
cossim = torch.zeros(embd1.size(0), embd0.size(0))
for i in range(embd1.size(0)):
cossim[i] = F.cosine_similarity(
embd0, embd1[i : i + 1].expand(embd0.size(0), embd1.size(1)), dim=-1
).view(-1)
ind = cossim.argsort(dim=1, descending=True)[:, : k[1]]
a1 = 0
ak0 = 0
ak1 = 0
for i, node in enumerate(g_list):
if ind[node, 0].item() == g_map[node]:
a1 += 1
ak0 += 1
ak1 += 1
else:
for j in range(1, k[0]):
if ind[node, j].item() == g_map[node]:
ak0 += 1
ak1 += 1
break
else:
for l in range(k[0], k[1]):
if ind[node, l].item() == g_map[node]:
ak1 += 1
break
a1 /= len(g_list)
ak0 /= len(g_list)
ak1 /= len(g_list)
print(f"H@1:{a1*100}; H@{k[0]}:{ak0*100}; H@{k[1]}:{ak1*100}")
return a1, ak0, ak1