| | import torch |
| | import sys |
| | sys.argv = [''] |
| | from sklearn.preprocessing import StandardScaler |
| | import pytorch_lightning as pl |
| | from torch.utils.data import DataLoader |
| | from lightning.pytorch.utilities.combined_loader import CombinedLoader |
| | import numpy as np |
| | from scipy.spatial import cKDTree |
| | import math |
| | from functools import partial |
| | from sklearn.cluster import KMeans, DBSCAN |
| | import matplotlib.pyplot as plt |
| | import pandas as pd |
| | from torch.utils.data import TensorDataset |
| |
|
| | from train.parsers_sc import parse_args |
| | args = parse_args() |
| |
|
class WeightedBranchedCellDataModule(pl.LightningDataModule):
    """Lightning data module for a two-branch cell trajectory with sample weights.

    The CSV at ``self.data_path`` is expected to provide a ``samples`` column
    (time index 0, 1, 2) and two coordinate columns ``x1`` / ``x2``.  The
    cells at time 2 are split into two branches with k-means; every loader
    yields ``(coords, weight)`` pairs where the weight is 1.0 for the initial
    population and 0.5 for each branch endpoint.

    Attributes set during construction:
        coords_t0, coords_t1, coords_t2_1, coords_t2_2: float32 coordinate
            tensors for each time point / branch.
        time_labels: numpy array with the time index of every row of the
            four tensors above, concatenated in that order.
        dataset: float32 tensor with all coordinates stacked by time.
        tree: ``cKDTree`` built on the same rows as ``dataset``.
        train_dataloaders, val_dataloaders, test_dataloaders: dicts of
            ``DataLoader`` objects keyed by population name.
        metric_samples_dataloaders: list of full-batch loaders, one per
            k-means cluster of ``dataset``.
    """

    def __init__(self, args):
        """Store the configuration and eagerly load / split the data.

        Args:
            args: namespace-like object providing ``batch_size``, ``dim``,
                ``whiten``, ``split_ratios`` and ``metric_clusters``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.data_path = "./data/mouse_hematopoiesis.csv"
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        # NOTE: stored for callers; this class never applies whitening itself.
        self.whiten = args.whiten
        self.k = 20
        # Placeholder; overwritten in _prepare_data with the real t=0 count.
        self.n_samples = 1429
        self.num_timesteps = 3
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Read the CSV, derive the two branches and build every dataloader."""
        print("Preparing cell data in BranchedCellDataModule")

        df = pd.read_csv(self.data_path)

        # Coordinates grouped by time index, keys in ascending order.
        coords_by_t = {
            t: df[df["samples"] == t][["x1", "x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0]
        self.n_samples = n0

        # Split the final time point into two branches with k-means.
        km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
        df2 = df[df["samples"] == 2].copy()
        df2["branch"] = km.labels_

        cluster_counts = df2["branch"].value_counts().sort_index()
        print(cluster_counts)

        # Draw exactly n0 cells per branch so that the endpoints align
        # row-for-row with the initial population.  pandas raises a
        # ValueError if a branch holds fewer than n0 cells.
        endpoints = {
            b: (
                df2[df2["branch"] == b]
                .sample(n=n0, random_state=42)[["x1", "x2"]]
                .values
            )
            for b in (0, 1)
        }

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
        x1_1 = torch.tensor(endpoints[0], dtype=torch.float32)
        x1_2 = torch.tensor(endpoints[1], dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2_1 = x1_1
        self.coords_t2_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),
            np.ones(len(self.coords_t1)),
            np.ones(len(self.coords_t2_1)) * 2,
            np.ones(len(self.coords_t2_2)) * 2,
        ])

        split_index = int(n0 * self.split_ratios[0])
        # Keep at least one full batch in the validation split; the
        # validation loaders use drop_last=True and would otherwise be empty.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0, val_x0 = x0[:split_index], x0[split_index:]
        train_x1_1, val_x1_1 = x1_1[:split_index], x1_1[split_index:]
        train_x1_2, val_x1_2 = x1_2[:split_index], x1_2[split_index:]

        self.val_x0 = val_x0

        def weighted(points, weight):
            """Pair every point with a constant (N, 1) weight column."""
            return TensorDataset(
                points, torch.full((points.shape[0], 1), fill_value=weight)
            )

        def batched(dataset, shuffle):
            """Mini-batch loader that drops the trailing partial batch."""
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                drop_last=True,
            )

        self.train_dataloaders = {
            "x0": batched(weighted(train_x0, 1.0), shuffle=True),
            "x1_1": batched(weighted(train_x1_1, 0.5), shuffle=True),
            "x1_2": batched(weighted(train_x1_2, 0.5), shuffle=True),
        }

        self.val_dataloaders = {
            "x0": batched(weighted(val_x0, 1.0), shuffle=False),
            "x1_1": batched(weighted(val_x1_1, 0.5), shuffle=True),
            "x1_2": batched(weighted(val_x1_2, 0.5), shuffle=True),
        }

        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # Test loaders deliver everything in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(
                weighted(val_x0, 1.0),
                batch_size=self.val_x0.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            "dataset": DataLoader(
                TensorDataset(self.dataset),
                batch_size=self.dataset.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        }

        # Exactly 3 requests three metric clusters; any other value uses two.
        n_clusters = 3 if self.metric_clusters == 3 else 2
        self.metric_samples_dataloaders = self._build_metric_loaders(n_clusters)

    def _build_metric_loaders(self, n_clusters):
        """Cluster ``self.dataset`` and return one full-batch loader per cluster.

        The loaders are ordered by k-means label ``1, ..., n_clusters - 1, 0``.
        """
        samples = self.dataset.cpu().numpy()
        labels = KMeans(n_clusters=n_clusters, random_state=45).fit(samples).labels_

        loaders = []
        for label in [*range(1, n_clusters), 0]:
            cluster_pts = torch.tensor(samples[labels == label], dtype=torch.float32)
            loaders.append(
                DataLoader(
                    cluster_pts,
                    batch_size=cluster_pts.shape[0],
                    shuffle=False,
                    drop_last=False,
                )
            )
        return loaders

    def _combined(self, key, loaders):
        """Wrap ``loaders`` under ``key`` together with the metric loaders."""
        return CombinedLoader(
            {
                key: CombinedLoader(loaders, mode="min_size"),
                "metric_samples": CombinedLoader(
                    self.metric_samples_dataloaders, mode="min_size"
                ),
            },
            mode="max_size_cycle",
        )

    def train_dataloader(self):
        """Return the combined loader with keys ``train_samples`` / ``metric_samples``."""
        return self._combined("train_samples", self.train_dataloaders)

    def val_dataloader(self):
        """Return the combined loader with keys ``val_samples`` / ``metric_samples``."""
        return self._combined("val_samples", self.val_dataloaders)

    def test_dataloader(self):
        """Return the combined loader with keys ``test_samples`` / ``metric_samples``."""
        return self._combined("test_samples", self.test_dataloaders)

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting.

        ``points`` is accepted for interface compatibility but is not used;
        the returned callable is ``local_smoothing_op`` bound to this
        module's KD-tree and dataset.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """Pull ``x`` towards a distance-weighted mean of its nearest data points.

        Args:
            x: tensor of query points; indexed as (batch, dim) below.
            tree: KD-tree whose rows correspond one-to-one to ``dataset``.
            dataset: tensor holding the coordinates indexed by ``tree``.
            k: number of nearest neighbours to average over.
            temp: softmax temperature applied to the squared distances.

        Returns:
            ``0.7 * x + 0.3 * smoothed`` where ``smoothed`` is the weighted
            neighbour mean; same shape as ``x``.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        idx = np.asarray(idx)
        if k == 1:
            # The KD-tree squeezes the neighbour axis for k == 1; restore it
            # so the broadcasting below sees a (batch, k, dim) tensor.
            idx = idx.reshape(-1, 1)

        index = torch.as_tensor(idx, dtype=torch.long, device=dataset.device)
        # The dataset tensor is not moved with the model, so bring the
        # gathered neighbours onto the device / dtype of the query points.
        nearest_pts = dataset[index].to(device=x.device, dtype=x.dtype)

        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        # softmax is the stable form of exp(-d/temp) / sum(exp(-d/temp)): with
        # a small temp the plain exponentials underflow to zero for points
        # away from the data, which turned the normalisation into 0/0 = NaN.
        weights = torch.softmax(-dists / temp, dim=1)

        smoothed = (weights * nearest_pts).sum(dim=1)

        alpha = 0.3
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            't2_1': self.coords_t2_1,
            't2_2': self.coords_t2_2,
            'time_labels': self.time_labels
        }
| |
|
| |
|
| |
|
class SingleBranchCellDataModule(pl.LightningDataModule):
    """Lightning data module for a single (unbranched) cell trajectory.

    The CSV at ``self.data_path`` is expected to provide a ``samples`` column
    (time index 0, 1, 2) and two coordinate columns ``x1`` / ``x2``.  Time 0
    is the source population and time 2 the target; loaders yield
    ``(coords, weight)`` pairs with weight 1.0 for the source and 0.5 for
    the target.

    Attributes set during construction:
        dataset: float32 tensor with all coordinates stacked by time
            (standardised when ``whiten`` is truthy).
        tree: ``cKDTree`` built on the raw stacked coordinates.
        train_dataloaders, val_dataloaders, test_dataloaders: dicts of
            ``DataLoader`` objects keyed by population name.
        metric_samples_dataloaders: list of two full-batch loaders, one per
            k-means cluster of ``dataset``.
    """

    def __init__(self, args):
        """Store the configuration and eagerly load / split the data.

        Args:
            args: namespace-like object providing ``batch_size``, ``dim``,
                ``whiten`` and ``split_ratios``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.data_path = "./data/mouse_hematopoiesis.csv"
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20
        # Placeholder; overwritten in _prepare_data with the real t=0 count.
        self.n_samples = 1429
        self.num_timesteps = 3
        self.split_ratios = args.split_ratios
        # NOTE(review): this attribute says 3, but _prepare_data always
        # clusters the data into 2 groups - confirm which one is intended.
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Read the CSV and build the train / val / test / metric dataloaders."""
        print("Preparing cell data in SingleBranchCellDataModule")

        df = pd.read_csv(self.data_path)

        # Coordinates grouped by time index, keys in ascending order.
        coords_by_t = {
            t: df[df["samples"] == t][["x1", "x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0]
        self.n_samples = n0

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)
        x1 = torch.tensor(coords_by_t[2], dtype=torch.float32)

        split_index = int(n0 * self.split_ratios[0])
        # Keep at least one full batch of x0 in the validation split; the
        # validation loaders use drop_last=True and would otherwise be empty.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0, val_x0 = x0[:split_index], x0[split_index:]
        # The same index splits the target population, whatever its size.
        train_x1, val_x1 = x1[:split_index], x1[split_index:]

        self.val_x0 = val_x0

        def weighted(points, weight):
            """Pair every point with a constant (N, 1) weight column."""
            return TensorDataset(
                points, torch.full((points.shape[0], 1), fill_value=weight)
            )

        def batched(dataset, shuffle):
            """Mini-batch loader that drops the trailing partial batch."""
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                drop_last=True,
            )

        self.train_dataloaders = {
            "x0": batched(weighted(train_x0, 1.0), shuffle=True),
            "x1": batched(weighted(train_x1, 0.5), shuffle=True),
        }

        self.val_dataloaders = {
            "x0": batched(weighted(val_x0, 1.0), shuffle=False),
            "x1": batched(weighted(val_x1, 0.5), shuffle=True),
        }

        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        if self.whiten:
            # NOTE(review): only self.dataset is standardised here; the
            # KD-tree above and the train/val tensors stay in raw
            # coordinates, so get_manifold_proj pairs a raw-space neighbour
            # search with whitened neighbour coordinates - confirm intent.
            self.scaler = StandardScaler()
            self.dataset = torch.tensor(
                self.scaler.fit_transform(all_data), dtype=torch.float32
            )

        # Test loaders deliver everything in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(
                weighted(val_x0, 1.0),
                batch_size=self.val_x0.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            "dataset": DataLoader(
                TensorDataset(self.dataset),
                batch_size=self.dataset.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        }

        # Two metric clusters over the (possibly whitened) dataset; the
        # loaders are ordered cluster 1 first, then cluster 0.
        samples = self.dataset.cpu().numpy()
        labels = KMeans(n_clusters=2, random_state=45).fit(samples).labels_

        self.metric_samples_dataloaders = []
        for label in (1, 0):
            cluster_pts = torch.tensor(samples[labels == label], dtype=torch.float32)
            self.metric_samples_dataloaders.append(
                DataLoader(
                    cluster_pts,
                    batch_size=cluster_pts.shape[0],
                    shuffle=False,
                    drop_last=False,
                )
            )

    def _combined(self, key, loaders):
        """Wrap ``loaders`` under ``key`` together with the metric loaders."""
        return CombinedLoader(
            {
                key: CombinedLoader(loaders, mode="min_size"),
                "metric_samples": CombinedLoader(
                    self.metric_samples_dataloaders, mode="min_size"
                ),
            },
            mode="max_size_cycle",
        )

    def train_dataloader(self):
        """Return the combined loader with keys ``train_samples`` / ``metric_samples``."""
        return self._combined("train_samples", self.train_dataloaders)

    def val_dataloader(self):
        """Return the combined loader with keys ``val_samples`` / ``metric_samples``."""
        return self._combined("val_samples", self.val_dataloaders)

    def test_dataloader(self):
        """Return the combined loader with keys ``test_samples`` / ``metric_samples``."""
        return self._combined("test_samples", self.test_dataloaders)

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting.

        ``points`` is accepted for interface compatibility but is not used;
        the returned callable is ``local_smoothing_op`` bound to this
        module's KD-tree and dataset.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """Pull ``x`` towards a distance-weighted mean of its nearest data points.

        Args:
            x: tensor of query points; indexed as (batch, dim) below.
            tree: KD-tree whose rows correspond one-to-one to ``dataset``.
            dataset: tensor holding the coordinates indexed by ``tree``.
            k: number of nearest neighbours to average over.
            temp: softmax temperature applied to the squared distances.

        Returns:
            ``0.7 * x + 0.3 * smoothed`` where ``smoothed`` is the weighted
            neighbour mean; same shape as ``x``.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        idx = np.asarray(idx)
        if k == 1:
            # The KD-tree squeezes the neighbour axis for k == 1; restore it
            # so the broadcasting below sees a (batch, k, dim) tensor.
            idx = idx.reshape(-1, 1)

        index = torch.as_tensor(idx, dtype=torch.long, device=dataset.device)
        # The dataset tensor is not moved with the model, so bring the
        # gathered neighbours onto the device / dtype of the query points.
        nearest_pts = dataset[index].to(device=x.device, dtype=x.dtype)

        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        # softmax is the stable form of exp(-d/temp) / sum(exp(-d/temp)): with
        # a small temp the plain exponentials underflow to zero for points
        # away from the data, which turned the normalisation into 0/0 = NaN.
        weights = torch.softmax(-dists / temp, dim=1)

        smoothed = (weights * nearest_pts).sum(dim=1)

        alpha = 0.3
        return (1 - alpha) * x + alpha * smoothed
| |
|
def get_datamodule():
    """Create the weighted branched data module and run its fit-stage setup.

    The module is constructed from the module-level ``args`` namespace;
    ``setup(stage="fit")`` is invoked before the instance is returned.
    """
    module = WeightedBranchedCellDataModule(args)
    module.setup(stage="fit")
    return module