"""Data preprocessing utilities for GRAVITY training stages."""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Sequence
from ..utils import log_verbose, resolve_path
from .datasets import PreprocessDataset, CustomDataset, CustomDatasetGeneWise
# Public API: the names exported by a star-import of this module.
__all__ = [
    "preprocess_counts",
    "load_cell_stage_dataset",
    "load_gene_stage_dataset",
    "load_gene_order",
    "resolve_gene_order",
    "assert_gene_order_matches",
    # NOTE(review): no definition of this name is visible in the reviewed
    # portion of the file; confirm it is defined (or imported) elsewhere.
    "export_intermediate_from_h5ad",
    "adata_to_df_with_embed",  # self-contained AnnData -> long CSV exporter
]
def _norm_gene(gene: str) -> str:
return str(gene).strip().upper()
[docs]
def load_gene_order(path: str) -> list[str]:
    """Read a gene order file (one gene symbol per line) into a list.

    Every line is normalized with ``_norm_gene`` and blank lines are dropped.
    The order is part of the model contract for pretrained checkpoints because
    attention and solver tensors are indexed by gene position.

    Parameters
    ----------
    path
        Location of the newline-delimited file; passed through ``resolve_path``.

    Returns
    -------
    list[str]
        Normalized gene symbols in file order.

    Raises
    ------
    ValueError
        If the file contains no genes, or if any gene is listed more than once.
    """
    gene_path = resolve_path(path)
    raw_lines = Path(gene_path).read_text().splitlines()
    genes = [name for name in map(_norm_gene, raw_lines) if name]
    if not genes:
        raise ValueError(f"gene order file is empty: {gene_path}")

    seen = set()
    repeated = set()
    for name in genes:
        # First sighting lands in `seen`; any later sighting marks a duplicate.
        (repeated if name in seen else seen).add(name)

    if repeated:
        duplicates = sorted(repeated)
        preview = ", ".join(duplicates[:10])
        suffix = "..." if len(duplicates) > 10 else ""
        raise ValueError(
            f"gene order file contains duplicate genes: {preview}" + suffix
        )
    return genes
[docs]
def resolve_gene_order(
    gene_subset: Optional[Sequence[str]] = None,
    gene_order_path: Optional[str] = None,
) -> Optional[list[str]]:
    """Resolve the effective gene list, preserving checkpoint-compatible order.

    Gene names in ``gene_subset`` are normalized with ``_norm_gene`` and blank
    entries are discarded, whether or not ``gene_order_path`` is supplied, so
    both call styles treat stray empty strings the same way.

    Parameters
    ----------
    gene_subset
        Optional genes the caller wants to keep. ``None`` means "no restriction".
    gene_order_path
        Optional path to a gene order file readable by ``load_gene_order``.

    Returns
    -------
    list[str] or None
        * ``None`` when neither argument is given.
        * The normalized ``gene_subset`` (caller's order) when only the subset
          is given.
        * The full file order when only ``gene_order_path`` is given.
        * The requested genes arranged in file order when both are given.

    Raises
    ------
    ValueError
        If ``gene_subset`` names a gene that is absent from the order file
        (or if the order file itself is rejected by ``load_gene_order``).
    """
    requested = None
    if gene_subset is not None:
        # Normalize each name once, then drop blanks.
        normalized = (_norm_gene(gene) for gene in gene_subset)
        requested = [gene for gene in normalized if gene]

    if gene_order_path is None:
        return requested

    ordered = load_gene_order(gene_order_path)
    if requested is None:
        return ordered

    ordered_set = set(ordered)
    missing = [gene for gene in requested if gene not in ordered_set]
    if missing:
        preview = ", ".join(missing[:10])
        raise ValueError(
            f"gene_subset contains genes not present in gene_order_path: {preview}"
            + ("..." if len(missing) > 10 else "")
        )

    # Emit in file order (not request order) so gene indices match the checkpoint.
    requested_set = set(requested)
    return [gene for gene in ordered if gene in requested_set]
[docs]
def assert_gene_order_matches(path: Path, expected: Sequence[str], *, label: str = "genes.txt") -> None:
    """Check that a gene-order file already on disk agrees with ``expected``.

    Nothing happens when ``path`` does not exist. Otherwise the file's lines
    and ``expected`` are both normalized with ``_norm_gene`` (blank entries
    dropped) and compared element by element.

    Parameters
    ----------
    path
        Existing gene-order file to verify.
    expected
        Gene order the current run intends to use.
    label
        Human-readable name of the file, used only in the error message.

    Raises
    ------
    RuntimeError
        If the two orders differ; the message reports the first differing
        position, or the two lengths when one list is a prefix of the other.
    """
    if not path.exists():
        return

    def _clean(names):
        # Normalize every entry and discard the ones that end up empty.
        return [name for name in map(_norm_gene, names) if name]

    current = _clean(path.read_text().splitlines())
    expected_norm = _clean(expected)
    if current == expected_norm:
        return

    first_mismatch = next(
        (
            f"first mismatch at position {idx}: existing={cur_gene}, expected={exp_gene}"
            for idx, (cur_gene, exp_gene) in enumerate(zip(current, expected_norm))
            if cur_gene != exp_gene
        ),
        None,
    )
    if first_mismatch is None:
        # All shared positions agree, so the lists can only differ in length.
        first_mismatch = f"length differs: existing={len(current)}, expected={len(expected_norm)}"

    raise RuntimeError(
        f"Existing {label} at {path} does not match the current gene order "
        f"({first_mismatch}). Use a different workdir, remove stale outputs, "
        "or pass the checkpoint-matching gene_order_path."
    )
[docs]
def adata_to_df_with_embed(
    adata,
    us_para: Sequence[str] = ("Mu", "Ms"),
    cell_type_para: str = "celltype",
    embed_para: str = "X_umap",
    save_path: str = "cell_type_u_s_sample_df.csv",
    gene_list: Optional[Sequence[str]] = None,
):
    """Convert an AnnData object to a long CSV with per-gene/per-cell rows and 2D embedding.

    The resulting CSV contains, for every (gene, cell) pair:
      - gene_name, unsplice, splice
      - cellID, clusters (cell-type), embedding1, embedding2

    Notes
    -----
    - The two matrices are taken from ``adata.layers[us_para[0]]`` (unspliced)
      and ``adata.layers[us_para[1]]`` (spliced). By default, that is
      ``('Mu', 'Ms')``. Layers exposing ``toarray()`` (e.g. SciPy sparse
      matrices) are densified one column at a time.
    - The 2D embedding is read from the first two columns of
      ``adata.obsm[embed_para]`` (e.g. ``'X_umap'``).
    - All inputs (genes, layers, cell-type column, embedding) are validated
      *before* anything is written, so a bad argument never leaves a partial
      file at ``save_path``.
    - The gene table is written one gene at a time, then read back and joined
      with the repeated per-cell metadata in memory; the final step therefore
      holds the complete ``n_genes * n_cells`` table.
    - If a gene name occurs more than once in ``adata.var.index``, the first
      matching column is exported.

    Parameters
    ----------
    adata
        An ``anndata.AnnData``-like object providing ``var.index``, ``obs``,
        ``layers`` and ``obsm``.
    us_para
        Names of the two layers for unspliced and spliced moments/counts,
        respectively; default ``('Mu', 'Ms')``.
    cell_type_para
        Column name in ``adata.obs`` that holds cell-type labels.
    embed_para
        Key in ``adata.obsm`` for the 2D embedding.
    save_path
        Destination CSV file path (overwritten if it exists).
    gene_list
        Specific genes to export. If None, use all genes (``adata.var.index``).

    Returns
    -------
    pandas.DataFrame
        The final DataFrame that was saved to ``save_path``.

    Raises
    ------
    ValueError
        If there are no genes to export, or the embedding has fewer than two
        columns.
    KeyError
        If a requested gene, a layer, the cell-type column or the embedding
        key is missing from ``adata``.
    RuntimeError
        If the re-read gene table and the repeated metadata disagree in length.
    """
    # Local imports to avoid adding hard dependencies at module import time.
    import numpy as np
    import pandas as pd

    # tqdm is optional; fall back to a pass-through iterator if it is not installed.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(x, **kwargs):  # type: ignore
            return x

    genes = list(adata.var.index) if gene_list is None else list(gene_list)

    # ---- Validate every input up front, before touching the filesystem. ----
    if not genes:
        raise ValueError("[gravity] no genes to export: gene list is empty.")

    unspliced_name, spliced_name = us_para[0], us_para[1]
    for layer_name in (unspliced_name, spliced_name):
        if layer_name not in adata.layers:
            raise KeyError(f"[gravity] layer '{layer_name}' not found in adata.layers.")

    if cell_type_para not in adata.obs:
        raise KeyError(f"[gravity] column '{cell_type_para}' not found in adata.obs.")

    if embed_para not in adata.obsm:
        raise KeyError(f"[gravity] embedding '{embed_para}' not found in adata.obsm.")
    embedding = np.asarray(adata.obsm[embed_para])
    if embedding.ndim < 2 or embedding.shape[1] < 2:
        raise ValueError(f"[gravity] embedding '{embed_para}' must have at least 2 columns.")

    # Map each gene name to its first column position once, instead of
    # scanning var.index (and copying an AnnData slice) for every gene.
    first_column = {}
    for position, name in enumerate(adata.var.index):
        first_column.setdefault(name, position)
    missing = [gene for gene in genes if gene not in first_column]
    if missing:
        preview = ", ".join(str(gene) for gene in missing[:10])
        raise KeyError(
            f"[gravity] genes not found in adata.var.index: {preview}"
            + ("..." if len(missing) > 10 else "")
        )

    unspliced = adata.layers[unspliced_name]
    spliced = adata.layers[spliced_name]

    def _gene_column(layer, column):
        """Return one gene's values as a flat float32 vector (dense or sparse layer)."""
        values = layer[:, column]
        if hasattr(values, "toarray"):
            # Sparse column: densify before handing it to NumPy.
            values = values.toarray()
        return np.asarray(values, dtype=np.float32).ravel()

    # ---- Stream-write per-gene blocks to CSV (header for the first gene only). ----
    for i, gene in enumerate(tqdm(genes, desc="Export genes")):
        column = first_column[gene]
        df_g = pd.DataFrame(
            {
                "gene_name": gene,
                "unsplice": _gene_column(unspliced, column),
                "splice": _gene_column(spliced, column),
            },
            copy=False,
        )
        is_first = i == 0
        df_g.to_csv(save_path, mode="w" if is_first else "a", header=is_first, index=False)

    # ---- Build per-cell metadata (repeated for every gene). ----
    cell_ids = pd.DataFrame({"cellID": adata.obs.index})
    celltype_meta = adata.obs[cell_type_para].reset_index(drop=True)
    celltype = pd.DataFrame({"clusters": celltype_meta})
    embed_map = pd.DataFrame(
        {
            "embedding1": embedding[:, 0],
            "embedding2": embedding[:, 1],
        }
    )
    embed_info = pd.concat([cell_ids, celltype, embed_map], axis=1)
    embed_raw = pd.concat([embed_info] * len(genes), ignore_index=True)

    # ---- Read the just-written gene table back and append the metadata. ----
    raw_data = pd.read_csv(save_path)
    if len(raw_data) != len(embed_raw):
        # Defensive check to catch mismatches early.
        raise RuntimeError(
            f"[gravity] row mismatch: gene table has {len(raw_data)} rows, "
            f"but repeated cell-metadata has {len(embed_raw)}."
        )
    raw_data = pd.concat([raw_data, embed_raw], axis=1)
    raw_data.to_csv(save_path, header=True, index=False)
    return raw_data
[docs]
def preprocess_counts(
    input_file: str,
    output_csv: str,
    *,
    gene_order: Optional[Sequence[str]] = None,
) -> Path:
    """Build the cell-wise training table from a long single-cell CSV.

    The work is skipped when ``output_csv`` already exists; the existing file
    is not inspected, only its presence is checked.

    Parameters
    ----------
    input_file:
        Path to the raw long-format CSV with columns including
        `cellID`, `gene_name`, `unsplice`, `splice`, `embedding1`, `embedding2`.
        Resolved through ``resolve_path``.
    output_csv:
        Destination CSV containing one row per cell with serialized gene
        tuples. Its parent directory is created if needed.
    gene_order:
        Optional ordered gene list, forwarded to ``PreprocessDataset``. When
        provided, matching gene columns are written first in this order so
        downstream checkpoints see the intended gene-index layout.

    Returns
    -------
    pathlib.Path
        Absolute path of the intermediate CSV (pre-existing or newly produced).
    """
    input_path = resolve_path(input_file)
    output_path = Path(output_csv).resolve()

    if not output_path.exists():
        log_verbose(f"[gravity] preprocessing raw counts from {input_path} → {output_path}", level=1)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Constructed for its side effect only; the instance itself is not kept.
        PreprocessDataset(str(input_path), str(output_path), gene_order=gene_order)
    else:
        log_verbose(f"[gravity] found existing preprocessed file: {output_path}; skip.", level=1)

    return output_path
[docs]
def load_cell_stage_dataset(
    middle_file: str,
    *,
    prior_path: str = './prior_data/nichenet_mouse.zip',
    gene_list: Optional[Sequence[str]] = None,
    n_pos_neighbors = 30,
    n_neg_neighbors = 10,
) -> CustomDataset:
    """Build the ``CustomDataset`` used for the cell-wise stage and log its size.

    Parameters
    ----------
    middle_file
        First positional argument of ``CustomDataset``.
    prior_path
        Forwarded as ``prior``.
    gene_list
        Forwarded as ``gene_select``.
    n_pos_neighbors, n_neg_neighbors
        Forwarded unchanged under the same keyword names.

    Returns
    -------
    CustomDataset
        The constructed dataset.
    """
    cell_dataset = CustomDataset(
        middle_file,
        prior=prior_path,
        gene_select=gene_list,
        n_pos_neighbors=n_pos_neighbors,
        n_neg_neighbors=n_neg_neighbors,
    )
    n_cells = len(cell_dataset)
    n_hvgs = len(cell_dataset.hvg)
    summary = f"[gravity] loaded cell-wise dataset with {n_cells} cells and {n_hvgs} HVGs"
    log_verbose(summary, level=2)
    return cell_dataset
[docs]
def load_gene_stage_dataset(
    middle_file: str,
    *,
    prior_path: str = './prior_data/nichenet_mouse.zip',
    future_positions: str = './final_positions_with_index_yixian.npy',
    gene_list: Optional[Sequence[str]] = None,
) -> CustomDatasetGeneWise:
    """Build the ``CustomDatasetGeneWise`` used for the gene-wise refinement stage and log its size.

    Parameters
    ----------
    middle_file
        First positional argument of ``CustomDatasetGeneWise``.
    prior_path
        Forwarded as ``prior``.
    future_positions
        Forwarded as ``future_pos``.
    gene_list
        Forwarded as ``gene_select``.

    Returns
    -------
    CustomDatasetGeneWise
        The constructed dataset.
    """
    gene_dataset = CustomDatasetGeneWise(
        middle_file,
        prior=prior_path,
        gene_select=gene_list,
        future_pos=future_positions,
    )
    n_cells = len(gene_dataset)
    n_hvgs = len(gene_dataset.hvg)
    summary = f"[gravity] loaded gene-wise dataset with {n_cells} cells and {n_hvgs} HVGs"
    log_verbose(summary, level=2)
    return gene_dataset