# Source code for gravity.train.cell_stage

"""High-level helpers to train the GRAVITY cell-wise stage.

The functions in this module prepare datasets, configure PyTorch Lightning
trainers, and export artefacts (CSV predictions, TF scores, priors) expected by
downstream steps. They do not change the model's inner computations.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split

try:  # Lightning 2.x
    from pytorch_lightning.callbacks import TQDMProgressBar
except ImportError:  # pragma: no cover - Lightning 1.x fallback path
    TQDMProgressBar = None

from ..data.preprocessing import (
    assert_gene_order_matches,
    load_cell_stage_dataset,
    resolve_gene_order,
)
from .cell_model import FullModelCellWise
from ..utils import log_verbose, resolve_path
import time

# Public names exported by a star-import of this module.
__all__ = ["CellStageConfig", "train_cell_stage"]


class ExportProgressBar(object if TQDMProgressBar is None else TQDMProgressBar):
    """TQDM progress bar whose test-phase bar carries a caller-chosen label.

    When ``TQDMProgressBar`` could not be imported the class degrades to a
    plain object that only remembers the description.
    """

    def __init__(self, description: str, refresh_rate: int = 1):
        # Only initialise the Lightning base class when it actually exists;
        # the ``object`` fallback takes no ``refresh_rate`` argument.
        if TQDMProgressBar is not None:
            super().__init__(refresh_rate=refresh_rate)
        self._description = description

    def init_test_tqdm(self):  # pragma: no cover - Lightning handles display
        """Build the test-phase bar and stamp the stored description on it."""
        if TQDMProgressBar is None:
            return None
        test_bar = super().init_test_tqdm()
        if test_bar is None:
            return None
        test_bar.set_description(self._description)
        return test_bar


@dataclass
class CellStageConfig:
    """Configuration bundle for the cell-wise training stage.

    See also the top-level :class:`gravity.pipeline.PipelineConfig` for how
    device/distribution options propagate. ``gene_order_path`` fixes the
    checkpoint-compatible gene index order when using pretrained/reference
    weights.

    Only ``raw_counts`` and ``middle_csv`` are required; every other field has
    a default. Field order is part of the positional constructor interface.
    """

    # --- Inputs -----------------------------------------------------------
    # CSV read by ``train_cell_stage``; needs a ``cellIndex`` or ``cellID``
    # column, and may carry a ``clusters`` column.
    raw_counts: str
    # Handed to ``load_cell_stage_dataset`` as the dataset source.
    middle_csv: str
    # Prior network file, resolved with ``resolve_path``; ``None`` passes no
    # prior path to the dataset loader.
    prior_network: Optional[str] = './prior_data/nichenet_mouse.zip'

    # --- Outputs (file names are relative to ``output_dir``) ---------------
    output_dir: str = 'gravity_outputs'
    stage1_csv: str = 'stage1.csv'
    checkpoint_name: str = 'stage1.ckpt'
    # When set, ``train_cell_stage`` loads this checkpoint and only runs the
    # inference/export pass instead of fitting a new model.
    pretrained_checkpoint: Optional[str] = None
    # Sub-directory of ``output_dir`` for attention outputs.
    attention_dir: str = 'attentions'

    # --- Gene selection (both go to ``resolve_gene_order``) ----------------
    gene_subset: Optional[Sequence[str]] = None
    gene_order_path: Optional[str] = None

    # --- Data loading / trainer options ------------------------------------
    batch_size: int = 32
    epochs: int = 6  # forwarded as the trainer's ``max_epochs``
    accelerator: str = 'auto'
    # ``None`` makes ``train_cell_stage`` fall back to a single device.
    devices: Optional[Union[int, Sequence[int]]] = None
    num_workers: int = 0
    # A validation split is only made when this lies strictly between 0 and 1.
    val_fraction: float = 0.0
    attention_topk: int = 64  # forwarded to the model as ``csr_topk``
    attention_output: bool = True
    # The next three are forwarded to the trainer only when not ``None``.
    precision: Optional[Union[int, str]] = None
    gradient_clip_val: Optional[float] = None
    strategy: Optional[str] = None
    seed: int = 42
    log_every_n_steps: int = 50
    progress_bar: bool = True

    # --- Model hyper-parameters (forwarded to ``FullModelCellWise``) -------
    learning_rate: float = 1e-6
    embedding_size: int = 16
    model_dimension: int = 16
    ffn_dimension: int = 16

    # --- Neighbour sampling (forwarded to ``load_cell_stage_dataset``) -----
    n_pos_neighbors: int = 30
    n_neg_neighbors: int = 10
def _split_for_validation(dataset, val_fraction: float, seed: int):
    """Split a dataset deterministically into train/val subsets if requested.

    Parameters
    ----------
    dataset
        Any sized, indexable dataset accepted by ``random_split``.
    val_fraction
        Share of items to hold out. A split is only made when the value lies
        strictly between 0 and 1.
    seed
        Seed for the private generator that drives the shuffle, so the same
        seed always yields the same partition.

    Returns
    -------
    tuple
        ``(train_subset, val_subset)``. When no split is made (fraction out of
        range, or at most one item) the result is ``(dataset, None)``.
    """
    if not 0.0 < val_fraction < 1.0:
        return dataset, None

    total = len(dataset)
    if total <= 1:
        # Nothing meaningful to hold out from zero or one item.
        return dataset, None

    # Hold out at least one item, even for very small fractions.
    val_size = max(1, int(total * val_fraction))
    train_size = total - val_size
    if train_size <= 0:
        # Defensive: never let the validation part swallow the whole dataset.
        train_size = total - 1
        val_size = 1

    # A dedicated generator keeps the split independent of global RNG state.
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(
        dataset, [train_size, val_size], generator=generator
    )
    return train_subset, val_subset
def _strategy_attr(trainer, name: str):
    """Return attribute ``name`` from the trainer's strategy object, or ``None``.

    ``trainer.strategy`` is consulted first, then ``trainer.training_type_plugin``
    (presumably the name used by older Lightning releases). The first holder
    that actually exposes ``name`` wins.
    """
    for holder_name in ("strategy", "training_type_plugin"):
        holder = getattr(trainer, holder_name, None)
        if holder is not None and hasattr(holder, name):
            return getattr(holder, name)
    return None


def _existing_outputs_reusable(stage1_csv_path: Path, expected_cells: int) -> bool:
    """Tell whether an existing stage-1 CSV matches the current dataset.

    The CSV is reusable when its number of distinct ``cellIndex`` values equals
    ``expected_cells`` and neither ``alpha`` nor ``beta`` contains missing
    values. Any failure to read the file is logged and treated as "not
    reusable" so that the caller simply retrains.
    """
    try:
        df_stats = pd.read_csv(stage1_csv_path, usecols=['cellIndex', 'alpha', 'beta'])
        existing_cells = df_stats['cellIndex'].nunique()
        has_complete_rates = bool(df_stats[['alpha', 'beta']].notna().all().all())
    except Exception as exc:  # deliberate best effort: an unreadable CSV means retrain
        log_verbose(f"[gravity] failed to inspect existing stage1 CSV ({exc}); retraining.", level=1)
        existing_cells = -1
        has_complete_rates = False

    if existing_cells == expected_cells and has_complete_rates:
        log_verbose("[gravity] stage1 outputs match current dataset; reusing and skipping training.", level=1)
        return True

    log_verbose(
        f"[gravity] existing stage1 outputs mismatch dataset (cells: {existing_cells} vs {expected_cells}; complete rates: {has_complete_rates}); retraining.",
        level=1,
    )
    return False


def _load_cell_annotations(raw_counts_path):
    """Read the raw-counts CSV and derive the per-cell lookup tables.

    Returns
    -------
    tuple
        ``(cell_to_type, embedding_map, id_map)`` where

        * ``cell_to_type`` maps ``str(cellIndex)`` to ``str(clusters)`` and is
          empty when the table has no ``clusters`` column,
        * ``embedding_map`` maps ``cellIndex`` to the cluster label, using
          ``'NA'`` when there is no ``clusters`` column,
        * ``id_map`` maps ``cellIndex`` to ``cellID``, or to ``str(cellIndex)``
          when there is no ``cellID`` column.

    Raises
    ------
    KeyError
        If the table has neither a ``cellIndex`` nor a ``cellID`` column.
    """
    raw_df = pd.read_csv(raw_counts_path)
    columns = raw_df.columns

    if 'cellIndex' not in columns:
        if 'cellID' not in columns:
            raise KeyError(
                f"raw counts table {raw_counts_path} needs a 'cellIndex' or 'cellID' column"
            )
        # Derive a dense integer index from the order of first appearance.
        raw_df.insert(0, 'cellIndex', pd.factorize(raw_df['cellID'])[0])
    raw_df['cellIndex'] = raw_df['cellIndex'].astype(int)
    cell_index = raw_df['cellIndex']

    if 'clusters' in columns:
        clusters_series = raw_df['clusters']
        cell_to_type = dict(zip(cell_index.astype(str), clusters_series.astype(str)))
    else:
        clusters_series = pd.Series(['NA'] * len(raw_df), index=raw_df.index)
        cell_to_type = {}
    embedding_map = dict(zip(cell_index, clusters_series))

    if 'cellID' in columns:
        id_map = dict(zip(cell_index, raw_df['cellID']))
    else:
        id_map = {idx: str(idx) for idx in cell_index}
    return cell_to_type, embedding_map, id_map


def train_cell_stage(config: CellStageConfig) -> Dict[str, Optional[Path]]:
    """Run the GRAVITY cell-wise stage and export key artefacts.

    The function loads the dataset, writes ``genes.txt`` and ``genemap.json``
    into the output directory, and then takes one of three paths:

    1. existing outputs match the dataset and no pretrained checkpoint is
       configured: nothing is trained, the existing paths are returned;
    2. ``config.pretrained_checkpoint`` is set: the checkpoint is loaded and
       only the inference/export pass runs;
    3. otherwise a model is fitted, exported on global rank 0, and a checkpoint
       is saved.

    Parameters
    ----------
    config
        Stage configuration, see :class:`CellStageConfig`.

    Returns
    -------
    Dict[str, Optional[Path]]
        Paths under the keys ``stage1_csv``, ``checkpoint``, ``attention_dir``
        (``None`` when ``config.attention_output`` is false), ``genes_path``
        and ``gene_map``.

    Raises
    ------
    KeyError
        If the raw-counts table has neither ``cellIndex`` nor ``cellID``.
    RuntimeError
        If the stage-1 CSV is still missing about 60 seconds after export.
    """
    import json  # local import: json is not part of this module's import header

    pl.seed_everything(config.seed, workers=True)

    # ---- Resolve every output location once --------------------------------
    output_dir = Path(config.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    attention_dir = output_dir / config.attention_dir
    if config.attention_output:
        attention_dir.mkdir(parents=True, exist_ok=True)

    stage1_csv_path = output_dir / config.stage1_csv
    checkpoint_path = output_dir / config.checkpoint_name
    attn_h5ad = attention_dir / "attention_TF_scores_with_types.h5ad"
    genes_path = output_dir / 'genes.txt'
    mapper_path = output_dir / 'genemap.json'

    # Every exit path of this function reports the same set of artefacts.
    artefacts = {
        'stage1_csv': stage1_csv_path,
        'checkpoint': checkpoint_path,
        'attention_dir': attention_dir if config.attention_output else None,
        'genes_path': genes_path,
        'gene_map': mapper_path,
    }

    # ---- Dataset ------------------------------------------------------------
    if config.prior_network is None:
        prior_network_path = None
    else:
        prior_network_path = resolve_path(config.prior_network)

    effective_gene_subset = resolve_gene_order(config.gene_subset, config.gene_order_path)
    dataset = load_cell_stage_dataset(
        config.middle_csv,
        prior_path=prior_network_path,
        gene_list=effective_gene_subset,
        n_pos_neighbors=config.n_pos_neighbors,
        n_neg_neighbors=config.n_neg_neighbors,
    )
    total_cells = len(dataset)
    log_verbose(
        f"[gravity] stage1 dataset loaded: {total_cells} cells; val_fraction={config.val_fraction}",
        level=1,
    )
    hvgs = dataset.hvg
    # Checked before genes.txt is rewritten below, i.e. against the file as it
    # exists on entry.
    assert_gene_order_matches(genes_path, hvgs, label="stage1 genes.txt")

    # ---- Can previous outputs be reused? -------------------------------------
    # Short-circuiting keeps the CSV from being read unless all files exist.
    skip_stage1 = (
        config.pretrained_checkpoint is None
        and stage1_csv_path.exists()
        and checkpoint_path.exists()
        and (not config.attention_output or attn_h5ad.exists())
        and _existing_outputs_reusable(stage1_csv_path, total_cells)
    )

    # ---- Per-cell annotations and gene files --------------------------------
    raw_counts_path = resolve_path(config.raw_counts)
    cell_to_type, embedding_map, id_map = _load_cell_annotations(raw_counts_path)

    with genes_path.open('w') as fp:
        for gene in hvgs:
            fp.write(f"{gene}\n")
    log_verbose(f"[gravity] wrote {len(hvgs)} gene identifiers to {genes_path}", level=2)

    with mapper_path.open('w') as fp:
        json.dump(dataset.niche_dict, fp, indent=2)
    log_verbose(f"[gravity] stored TF→target prior map at {mapper_path}", level=2)

    if skip_stage1:
        # Gene files are refreshed above; no split or loaders are needed.
        return artefacts

    # ---- Train/validation loaders -------------------------------------------
    train_subset, val_subset = _split_for_validation(dataset, config.val_fraction, config.seed)
    if val_subset is None:
        log_verbose("[gravity] stage1 training uses all cells; no validation split.", level=1)
        train_dataset = dataset
        val_loader = None
    else:
        log_verbose(
            f"[gravity] stage1 split → train: {len(train_subset)} cells, val: {len(val_subset)} cells",
            level=1,
        )
        train_dataset = train_subset
        val_loader = DataLoader(
            val_subset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
        )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )

    # ---- Model and trainer arguments ----------------------------------------
    model_kwargs = dict(
        input_dim=len(hvgs) * 2 + 3,
        gene_list=hvgs,
        TF_list=dataset.TF_from_net,
        TG_list=dataset.TG_from_net,
        TFTG_map=dataset.niche_dict,
        TFTG_map_reverse=dataset.niche_dict_reverse,
        output_dim_trans=len(hvgs),
        embedding_map=embedding_map,
        id_map=id_map,
        gene_select=hvgs,
        nbrs=dataset.nbrs,
        origin_data=dataset.data,
        negs=dataset.negs,
        output_csv=str(stage1_csv_path),
        attention_output=config.attention_output,
        output_network_path=str(attention_dir),
        cell_to_type=cell_to_type,
        gene_list_path=str(genes_path),
        gene_mapper_path=str(mapper_path),
        csr_topk=config.attention_topk,
        learning_rate=config.learning_rate,
        embedding_size=config.embedding_size,
        model_dimension=config.model_dimension,
        ffn_dimension=config.ffn_dimension,
    )

    trainer_kwargs = dict(
        accelerator=config.accelerator,
        devices=config.devices if config.devices is not None else 1,
        max_epochs=config.epochs,
        logger=False,
        enable_checkpointing=False,
        log_every_n_steps=config.log_every_n_steps,
        enable_progress_bar=config.progress_bar,
        default_root_dir=str(output_dir),
    )
    # Optional settings are only forwarded when set, so the trainer's own
    # defaults apply otherwise.
    for option in ('precision', 'gradient_clip_val', 'strategy'):
        value = getattr(config, option)
        if value is not None:
            trainer_kwargs[option] = value

    def _export(model: FullModelCellWise, description: str, *, save_checkpoint: bool = True) -> None:
        """Run one single-device test pass over the full dataset.

        The test pass is expected to produce the stage-1 CSV; this helper
        waits for the file to appear and raises ``RuntimeError`` if it does not.
        """
        single_test_loader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
        )
        log_verbose(
            f"[gravity] {description}: computing velocities & attention matrices...",
            level=1,
        )
        export_callbacks = []
        if config.progress_bar and TQDMProgressBar is not None:
            export_callbacks.append(ExportProgressBar("Stage1 infer/export"))
        single_tester = pl.Trainer(
            accelerator='auto',
            devices=1,
            logger=False,
            enable_checkpointing=False,
            enable_progress_bar=config.progress_bar,
            default_root_dir=str(output_dir),
            callbacks=export_callbacks,
        )
        single_tester.test(model, dataloaders=single_test_loader)
        if save_checkpoint:
            single_tester.save_checkpoint(str(checkpoint_path))
        log_verbose("[gravity] stage1 infer/export finished; synchronising files...", level=1)

        # Poll for up to ~60 s in case the file becomes visible with a delay.
        for _ in range(60):
            if stage1_csv_path.exists():
                break
            time.sleep(1.0)
        if not stage1_csv_path.exists():
            raise RuntimeError(f"Stage1 CSV was not written: {stage1_csv_path}")

    # ---- Path 2: inference/export from a pretrained checkpoint --------------
    if config.pretrained_checkpoint is not None:
        pretrained_path = resolve_path(config.pretrained_checkpoint)
        log_verbose(f"[gravity] loading stage1 checkpoint for inference/export: {pretrained_path}", level=1)
        model = FullModelCellWise.load_from_checkpoint(str(pretrained_path), **model_kwargs)
        _export(model, "stage1 checkpoint infer/export")
        return artefacts

    # ---- Path 3: fit, then export -------------------------------------------
    model = FullModelCellWise(**model_kwargs)
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, train_loader, val_loader)

    # Only rank 0 runs the export pass, on a fresh copy of the trained weights.
    if getattr(trainer, "global_rank", 0) == 0:
        single_model = FullModelCellWise(**model_kwargs)
        single_model.load_state_dict(model.state_dict())
        _export(single_model, "stage1 infer/export", save_checkpoint=False)

    # Called on every rank: Lightning documents Trainer.save_checkpoint as a
    # call for all processes, and a rank-0-only call could mispair with the
    # barrier below. NOTE(review): confirm against the strategies in use.
    trainer.save_checkpoint(str(checkpoint_path))

    # Make all ranks wait until the export has finished when running distributed.
    world_size = _strategy_attr(trainer, "world_size")
    if world_size and world_size > 1:
        barrier_fn = _strategy_attr(trainer, "barrier")
        if barrier_fn is not None:
            barrier_fn()

    return artefacts