# Source code for gravity.pipeline

"""Pipeline configuration and execution for the GRAVITY workflow.

This module exposes a high-level configuration dataclass and a single
entry-point function that coordinates the complete GRAVITY pipeline:

- preprocess long-format counts into a wide table used by subsequent stages,
- train the cell-wise stage (stage 1), or run it from a pretrained checkpoint,
  and export its attention matrices,
- estimate future positions from the stage-1 outputs,
- train the gene-wise stage (stage 2), or run it from a pretrained checkpoint,
- optionally render velocity visualizations at the cell and gene level.

Multi-GPU and distributed execution are handled by PyTorch Lightning inside
each training stage; this module coordinates preprocessing, stage transitions,
future projection, and optional plotting.

References
----------
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union

import pandas as pd
import re
import time

from .data import preprocess_counts, export_intermediate_from_h5ad, resolve_gene_order
from .train import CellStageConfig, GeneStageConfig, train_cell_stage, train_gene_stage
from .tools.future import estimate_future_positions
from .plotting.velocity import plot_velocity_gene, plot_velocity_cell
from .utils import log_verbose, resolve_path

# Public API of this module: the configuration dataclass and its entry point.
__all__ = ["PipelineConfig", "run_pipeline"]


@dataclass
class PipelineConfig:
    """Configuration for the full GRAVITY pipeline.

    Parameters
    ----------
    raw_counts:
        Path to the input long-format CSV (must include at least
        ``cellID, gene_name, unsplice, splice, embedding1, embedding2``).
    workdir:
        Output directory where intermediate and final artifacts are written.
    prior_network:
        Path to the prior TF–target network archive used by GRAVITY.
        ``None`` is forwarded unchanged to both training stages.
    gene_subset:
        Optional list of genes to restrict training and evaluation.
    gene_order_path:
        Optional newline-delimited gene list that fixes gene-index order.
        Use this when running pretrained/reference checkpoints; model
        weights and attention matrices are aligned by gene position, not
        only by gene name.
    batch_size:
        Mini-batch size used by both training stages.
    n_pos_neighbors, n_neg_neighbors:
        Forwarded unchanged to the cell-wise stage only. Presumably the
        number of positive / negative neighbours sampled per cell —
        confirm against ``CellStageConfig``.
    stage1_epochs, stage2_epochs:
        Number of epochs per stage.
    stage1_lr, stage2_lr:
        Learning rates per stage. The unsupervised cell-wise stage and the
        contrastive gene-wise stage are somewhat learning-rate sensitive;
        for reference-style runs we recommend keeping ``stage1_lr`` below
        ``1e-5`` and tuning ``stage2_lr`` between ``1e-3`` and ``1e-5``.
    stage1_pretrained_checkpoint, stage2_pretrained_checkpoint:
        Optional checkpoints used for inference/export instead of training
        the corresponding stage. Use these for checkpoint-based reference
        reproduction.
    embedding_size, model_dimension, ffn_dimension:
        Model-size hyper-parameters forwarded unchanged to both stages.
        When a pretrained checkpoint is used these presumably must match
        the sizes it was trained with (NOTE(review): confirm).
    val_fraction_stage1, val_fraction_stage2:
        Fraction of data reserved for validation in each stage.
    accelerator, devices, strategy, precision, gradient_clip_val, num_workers:
        Forwarded to the stage trainers (PyTorch Lightning) to control
        device placement, distribution strategy and numerical precision.
    future_tau:
        Scaling factor governing the radius used in future-neighbor search.
    log_every_n_steps, progress_bar:
        Logging and progress display controls.
    make_plot:
        When true, velocity figures are rendered under
        ``<workdir>/velocity_plots``.
    plot_gene, plot_color, arrow_grid, arrow_scale:
        Cell-level plot options (``plot_gene`` is passed as ``gene`` and
        ``plot_color`` as ``color_by``); ``plot_color``, ``arrow_grid`` and
        ``arrow_scale`` also apply to gene-level plots.
    plot_genes:
        Genes to draw gene-level plots for: a single name, a sequence of
        names, or ``'all'`` (case-insensitive) for every gene in the
        stage-2 CSV. Ignored unless ``make_plot`` is true.
    middle_csv_name, stage1_csv_name, stage2_csv_name, future_positions_name,
    stage1_checkpoint_name, stage2_checkpoint_name:
        Filenames for artifacts written under ``workdir``.
    """

    # --- inputs / outputs -------------------------------------------------
    raw_counts: str
    workdir: str = 'gravity_outputs'
    prior_network: Optional[str] = './prior_data/nichenet_mouse.zip'
    gene_subset: Optional[Sequence[str]] = None
    gene_order_path: Optional[str] = None

    # --- training ---------------------------------------------------------
    batch_size: int = 16
    n_pos_neighbors: int = 30
    n_neg_neighbors: int = 10
    stage1_epochs: int = 6
    stage2_epochs: int = 6
    stage1_lr: float = 1e-6
    stage2_lr: float = 1e-4
    stage1_pretrained_checkpoint: Optional[str] = None
    stage2_pretrained_checkpoint: Optional[str] = None

    # --- model size (shared by both stages) -------------------------------
    embedding_size: int = 16
    model_dimension: int = 16
    ffn_dimension: int = 16

    # --- validation and trainer controls ----------------------------------
    val_fraction_stage1: float = 0.0
    val_fraction_stage2: float = 0.0
    accelerator: str = 'auto'
    devices: Optional[Union[int, Sequence[int]]] = None
    num_workers: int = 8
    strategy: Optional[str] = None
    precision: Optional[Union[int, str]] = None
    gradient_clip_val: Optional[float] = None

    # --- future projection and logging ------------------------------------
    future_tau: float = 0.5
    log_every_n_steps: int = 50
    progress_bar: bool = True

    # --- plotting ---------------------------------------------------------
    make_plot: bool = False
    plot_gene: Optional[str] = None
    plot_color: Optional[str] = 'clusters'
    plot_genes: Optional[Union[str, Sequence[str]]] = None
    arrow_grid: Tuple[int, int] = (20, 20)
    arrow_scale: float = 1.0

    # --- artifact filenames (relative to workdir) -------------------------
    middle_csv_name: str = 'combine.csv'
    stage1_csv_name: str = 'stage1.csv'
    stage2_csv_name: str = 'stage2.csv'
    future_positions_name: str = 'future_positions.npy'
    stage1_checkpoint_name: str = 'stage1.ckpt'
    stage2_checkpoint_name: str = 'stage2.ckpt'
# Sub-directory of ``workdir`` that receives the stage-1 attention export.
_ATTENTION_DIRNAME = 'attentions'
# Upper bound (seconds) on waiting for the stage-1 CSV after training returns;
# slow or network filesystems may expose the file late.
_STAGE1_CSV_TIMEOUT_S = 900.0


def run_pipeline(config: PipelineConfig) -> Dict[str, Path]:
    """Execute preprocessing, two training stages, future projection, and optional plotting.

    Parameters
    ----------
    config:
        :class:`PipelineConfig` instance that specifies inputs, training and
        plotting options, and output filenames.

    Returns
    -------
    Dict[str, pathlib.Path]
        Mapping of artifact names to paths. Keys always include
        ``middle_csv``, ``stage1_csv``, ``stage1_checkpoint``,
        ``attention_dir``, ``future_positions``, ``stage2_csv`` and
        ``stage2_checkpoint``; stage-produced entries are passed through as
        returned by the stage trainers. When plotting is enabled, successful
        figures add ``future_projection_plot``, ``cell_velocity_plot_stage1``,
        ``cell_velocity_plot_stage2`` and ``gene_velocity_plots`` (the last
        one maps to a *list* of paths).

    Raises
    ------
    FileNotFoundError
        If the stage-1 CSV does not appear within 15 minutes after stage-1
        training returns.

    Notes
    -----
    Multi-GPU controls in ``config`` are forwarded to PyTorch Lightning inside
    the stage trainers. Preprocessing, future projection and plotting always
    run in the main process. When using DDP with ``strategy='ddp'`` (spawn),
    child processes may re-import and execute the entry script; to avoid
    duplicated orchestration, this function performs an early return on
    non-zero ranks. Plotting is best-effort: failures are logged, not raised.
    """
    # DDP guard: workers only report where rank 0 writes its artifacts.
    if _global_rank() != 0:
        log_verbose("[gravity] run_pipeline invoked on non-zero rank; skipping in worker.", level=1)
        return _expected_outputs(config)

    workdir = Path(config.workdir).resolve()
    workdir.mkdir(parents=True, exist_ok=True)
    middle_csv_path = workdir / config.middle_csv_name

    raw_counts_path = resolve_path(config.raw_counts)
    prior_network_path = (
        None if config.prior_network is None else resolve_path(config.prior_network)
    )
    effective_gene_subset = resolve_gene_order(config.gene_subset, config.gene_order_path)

    # 1. Long-format counts -> intermediate table consumed by both stages.
    preprocess_counts(
        str(raw_counts_path),
        str(middle_csv_path),
        gene_order=effective_gene_subset,
    )

    # 2. Cell-wise stage (stage 1); attention export is always requested here.
    cell_cfg = CellStageConfig(
        raw_counts=str(raw_counts_path),
        middle_csv=str(middle_csv_path),
        prior_network=prior_network_path,
        output_dir=str(workdir),
        stage1_csv=config.stage1_csv_name,
        checkpoint_name=config.stage1_checkpoint_name,
        pretrained_checkpoint=config.stage1_pretrained_checkpoint,
        attention_dir=_ATTENTION_DIRNAME,
        gene_subset=effective_gene_subset,
        gene_order_path=config.gene_order_path,
        n_pos_neighbors=config.n_pos_neighbors,
        n_neg_neighbors=config.n_neg_neighbors,
        batch_size=config.batch_size,
        epochs=config.stage1_epochs,
        accelerator=config.accelerator,
        devices=config.devices,
        strategy=config.strategy,
        num_workers=config.num_workers,
        val_fraction=config.val_fraction_stage1,
        attention_topk=64,
        attention_output=True,
        log_every_n_steps=config.log_every_n_steps,
        progress_bar=config.progress_bar,
        precision=config.precision,
        gradient_clip_val=config.gradient_clip_val,
        learning_rate=config.stage1_lr,
        embedding_size=config.embedding_size,
        model_dimension=config.model_dimension,
        ffn_dimension=config.ffn_dimension,
    )
    stage1_outputs = train_cell_stage(cell_cfg)

    # Ensure the stage-1 CSV is visible before anything reads it.
    stage1_csv_path = Path(stage1_outputs['stage1_csv']).resolve()
    _wait_for_file(stage1_csv_path, 'stage1 CSV', timeout_s=_STAGE1_CSV_TIMEOUT_S)

    plots_dir: Optional[Path] = None
    plot_results: Dict[str, object] = {}
    if config.make_plot:
        plots_dir = workdir / 'velocity_plots'
        plots_dir.mkdir(parents=True, exist_ok=True)

    # 3. Future positions from the stage-1 outputs.
    future_positions_path = workdir / config.future_positions_name
    future_plot_path = (
        None if plots_dir is None else plots_dir / 'future_projection_embedding.png'
    )
    estimate_future_positions(
        str(stage1_csv_path),
        str(future_positions_path),
        tau=config.future_tau,
        show_plot=False,
        plot_path=None if future_plot_path is None else str(future_plot_path),
    )
    if future_plot_path is not None and future_plot_path.exists():
        plot_results['future_projection_plot'] = future_plot_path

    # 4. Gene-wise stage (stage 2).
    gene_cfg = GeneStageConfig(
        raw_counts=str(raw_counts_path),
        middle_csv=str(middle_csv_path),
        stage1_checkpoint=str(stage1_outputs['checkpoint']),
        future_positions=str(future_positions_path),
        prior_network=prior_network_path,
        output_dir=str(workdir),
        stage2_csv=config.stage2_csv_name,
        checkpoint_name=config.stage2_checkpoint_name,
        pretrained_checkpoint=config.stage2_pretrained_checkpoint,
        gene_subset=effective_gene_subset,
        gene_order_path=config.gene_order_path,
        batch_size=config.batch_size,
        epochs=config.stage2_epochs,
        accelerator=config.accelerator,
        devices=config.devices,
        strategy=config.strategy,
        num_workers=config.num_workers,
        val_fraction=config.val_fraction_stage2,
        log_every_n_steps=config.log_every_n_steps,
        progress_bar=config.progress_bar,
        precision=config.precision,
        gradient_clip_val=config.gradient_clip_val,
        learning_rate=config.stage2_lr,
        embedding_size=config.embedding_size,
        model_dimension=config.model_dimension,
        ffn_dimension=config.ffn_dimension,
    )
    stage2_outputs = train_gene_stage(gene_cfg)

    # 5. Optional, best-effort velocity figures.
    if plots_dir is not None:
        plot_results.update(
            _render_velocity_plots(
                config,
                plots_dir,
                stage1_csv=str(stage1_outputs['stage1_csv']),
                stage2_csv=str(stage2_outputs['stage2_csv']),
            )
        )

    outputs: Dict[str, object] = {
        'middle_csv': middle_csv_path,
        'stage1_csv': stage1_outputs['stage1_csv'],
        'stage1_checkpoint': stage1_outputs['checkpoint'],
        'attention_dir': stage1_outputs['attention_dir'],
        'future_positions': future_positions_path,
        'stage2_csv': stage2_outputs['stage2_csv'],
        'stage2_checkpoint': stage2_outputs['checkpoint'],
    }
    outputs.update(plot_results)
    return outputs


def _global_rank() -> int:
    """Return the distributed global rank from the environment (0 if unset/invalid)."""
    import os

    # Lightning's variable takes precedence over the generic torch.distributed one.
    raw_rank = os.environ.get("PL_GLOBAL_RANK", os.environ.get("RANK", "0"))
    try:
        return int(raw_rank)
    except ValueError:
        return 0


def _expected_outputs(config: PipelineConfig) -> Dict[str, Path]:
    """Return the artifact paths rank 0 would produce, derived from ``config`` alone."""
    workdir = Path(config.workdir).resolve()
    return {
        'middle_csv': workdir / config.middle_csv_name,
        'stage1_csv': workdir / config.stage1_csv_name,
        'stage1_checkpoint': workdir / config.stage1_checkpoint_name,
        'attention_dir': workdir / _ATTENTION_DIRNAME,
        'future_positions': workdir / config.future_positions_name,
        'stage2_csv': workdir / config.stage2_csv_name,
        'stage2_checkpoint': workdir / config.stage2_checkpoint_name,
    }


def _wait_for_file(path: Path, what: str, timeout_s: float, poll_s: float = 1.0) -> None:
    """Block until ``path`` exists, polling every ``poll_s`` seconds.

    Raises
    ------
    FileNotFoundError
        If ``path`` still does not exist after ``timeout_s`` seconds.
    """
    if path.exists():
        return
    log_verbose(f"[gravity] waiting for {what} to appear: {path}", level=1)
    # Monotonic clock: immune to wall-clock jumps and counts real elapsed time.
    deadline = time.monotonic() + timeout_s
    while not path.exists():
        if time.monotonic() >= deadline:
            raise FileNotFoundError(f"{what} not found after training: {path}")
        time.sleep(poll_s)


def _genes_to_plot(
    plot_genes: Optional[Union[str, Sequence[str]]],
    stage2_csv: str,
) -> list[str]:
    """Expand ``PipelineConfig.plot_genes`` into a list of gene names.

    ``'all'`` (case-insensitive) means every non-null ``gene_name`` in the
    stage-2 CSV. If that CSV cannot be read, the problem is logged and no gene
    plots are produced, because plotting is optional.
    """
    if not plot_genes:
        return []
    if not isinstance(plot_genes, str):
        return list(plot_genes)
    if plot_genes.lower() != 'all':
        return [plot_genes]
    try:
        # Only the gene column is needed; a missing column raises ValueError.
        gene_column = pd.read_csv(stage2_csv, usecols=['gene_name'])['gene_name']
    except (OSError, ValueError) as exc:
        log_verbose(
            f"[gravity] gene-level plotting skipped; could not read gene names from {stage2_csv}: {exc}",
            level=1,
        )
        return []
    return gene_column.dropna().astype(str).unique().tolist()


def _plot_cell_velocity(
    config: PipelineConfig,
    stage_csv: str,
    plots_dir: Path,
    output_name: str,
) -> Optional[Path]:
    """Render one cell-level velocity plot; return its path, or ``None`` on failure."""
    plot_path = plots_dir / output_name
    try:
        plot_velocity_cell(
            stage_csv,
            gene=config.plot_gene,
            color_by=config.plot_color,
            arrow_grid=config.arrow_grid,
            arrow_scale=config.arrow_scale,
            output_path=str(plot_path),
            show=False,
        )
    except Exception as exc:  # pragma: no cover - plotting optional
        log_verbose(f"[gravity] cell-level plotting failed ({output_name}): {exc}", level=1)
        return None
    return plot_path


def _render_velocity_plots(
    config: PipelineConfig,
    plots_dir: Path,
    stage1_csv: str,
    stage2_csv: str,
) -> Dict[str, object]:
    """Render cell- and gene-level velocity figures; failures are logged, not raised."""
    results: Dict[str, object] = {}

    for key, stage_csv, output_name in (
        ('cell_velocity_plot_stage1', stage1_csv, 'cell_velocity_stage1_embedding.png'),
        ('cell_velocity_plot_stage2', stage2_csv, 'cell_velocity_stage2_embedding.png'),
    ):
        cell_plot = _plot_cell_velocity(config, stage_csv, plots_dir, output_name)
        if cell_plot is not None:
            results[key] = cell_plot

    gene_plots: list[Path] = []
    used_paths: set[Path] = set()
    # dict.fromkeys drops repeated requests while keeping the caller's order.
    for gene_name in dict.fromkeys(_genes_to_plot(config.plot_genes, stage2_csv)):
        # Distinct genes may sanitize to the same stem (e.g. "a/b" vs "a:b");
        # add a numeric suffix so one plot never overwrites another.
        stem = f"gene_{_sanitize_name(gene_name)}"
        gene_path = plots_dir / f"{stem}_expression.png"
        suffix = 1
        while gene_path in used_paths:
            gene_path = plots_dir / f"{stem}_{suffix}_expression.png"
            suffix += 1
        used_paths.add(gene_path)

        try:
            plot_velocity_gene(
                stage2_csv,
                gene=gene_name,
                color_by=config.plot_color,
                arrow_grid=config.arrow_grid,
                arrow_scale=config.arrow_scale,
                output_path=str(gene_path),
                show=False,
            )
        except Exception as exc:  # plotting is best-effort
            log_verbose(f"[gravity] gene-level plotting failed for {gene_name}: {exc}", level=1)
            continue
        gene_plots.append(gene_path)

    if gene_plots:
        results['gene_velocity_plots'] = gene_plots
    return results
def _sanitize_name(name: str) -> str: """Return a filesystem-friendly version of ``name``. Only ASCII letters, digits and the characters ``.``, ``_`` and ``-`` are retained; all other runs are replaced by underscores. """ return re.sub(r"[^A-Za-z0-9._-]+", "_", name)