Source code for gravity.train.gene_stage

"""High-level helpers to refine GRAVITY models in the gene-wise stage.

This module loads the stage-1 checkpoint, freezes most parameters, and enables
fine-tuning of a restricted set of solver layers. Artefacts are exported in the
same format used by the cell-wise stage to ease downstream consumption.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import json
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split

try:  # Lightning 2.x
    from pytorch_lightning.callbacks import TQDMProgressBar
except ImportError:  # pragma: no cover - Lightning 1.x fallback path
    TQDMProgressBar = None

from ..data.preprocessing import (
    assert_gene_order_matches,
    load_gene_stage_dataset,
    resolve_gene_order,
)
from .gene_model import FullModelGeneWise
from ..utils import log_verbose, resolve_path
import time

__all__ = [
    "GeneStageConfig",
    "train_gene_stage",
]


class ExportProgressBar(TQDMProgressBar if TQDMProgressBar is not None else object):
    """Progress bar with custom description for export/inference passes."""

    def __init__(self, description: str, refresh_rate: int = 1):
        if TQDMProgressBar is None:  # pragma: no cover - fallback when class unavailable
            self._description = description
            return
        super().__init__(refresh_rate=refresh_rate)
        self._description = description

    def init_test_tqdm(self):  # pragma: no cover - Lightning handles display
        if TQDMProgressBar is None:
            return None
        bar = super().init_test_tqdm()
        if bar is not None:
            bar.set_description(self._description)
        return bar


[docs] @dataclass class GeneStageConfig: """Configuration bundle for the gene-wise fine-tuning stage. Device and distribution arguments mirror those in :class:`gravity.pipeline.PipelineConfig`. ``gene_order_path`` should point to the same gene list used by the stage-1 checkpoint when reproducing pretrained/reference runs. """ raw_counts: str middle_csv: str stage1_checkpoint: str future_positions: str prior_network: Optional[str] = './prior_data/nichenet_mouse.zip' output_dir: str = 'gravity_outputs' stage2_csv: str = 'stage2.csv' checkpoint_name: str = 'stage2.ckpt' pretrained_checkpoint: Optional[str] = None gene_subset: Optional[Sequence[str]] = None gene_order_path: Optional[str] = None batch_size: int = 32 epochs: int = 6 accelerator: str = 'auto' devices: Optional[Union[int, Sequence[int]]] = None num_workers: int = 0 val_fraction: float = 0.0 precision: Optional[Union[int, str]] = None gradient_clip_val: Optional[float] = None strategy: Optional[str] = None seed: int = 42 log_every_n_steps: int = 50 progress_bar: bool = True learning_rate: float = 1e-4 embedding_size: int = 16 model_dimension: int = 16 ffn_dimension: int = 16
def _split(dataset, val_fraction: float, seed: int): """Split a dataset deterministically into train/val subsets if requested.""" if not 0.0 < val_fraction < 1.0: return dataset, None total = len(dataset) if total <= 1: return dataset, None val_size = max(1, int(total * val_fraction)) train_size = total - val_size if train_size <= 0: train_size = total - 1 val_size = 1 generator = torch.Generator().manual_seed(seed) train_subset, val_subset = random_split(dataset, [train_size, val_size], generator=generator) return train_subset, val_subset
[docs] def train_gene_stage(config: GeneStageConfig) -> Dict[str, Path]: """Run the GRAVITY gene-wise refinement stage. Returns ------- Dict[str, Path] Paths to ``stage2_csv``, ``checkpoint`` and the gene/prior files. """ pl.seed_everything(config.seed, workers=True) output_dir = Path(config.output_dir).resolve() output_dir.mkdir(parents=True, exist_ok=True) stage2_csv_path = output_dir / config.stage2_csv checkpoint_path = output_dir / config.checkpoint_name genes_path = output_dir / 'genes.txt' mapper_path = output_dir / 'genemap.json' effective_gene_subset = resolve_gene_order(config.gene_subset, config.gene_order_path) dataset = load_gene_stage_dataset( config.middle_csv, prior_path=config.prior_network, future_positions=config.future_positions, gene_list=effective_gene_subset, ) total_cells = len(dataset) log_verbose( f"[gravity] stage2 dataset loaded: {total_cells} cells; val_fraction={config.val_fraction}", level=1, ) assert_gene_order_matches(genes_path, dataset.hvg, label="stage2 genes.txt") skip_stage2 = False if config.pretrained_checkpoint is None and stage2_csv_path.exists() and checkpoint_path.exists(): try: df_stats = pd.read_csv(stage2_csv_path, usecols=['cellIndex', 'alpha', 'beta']) existing_cells = df_stats['cellIndex'].nunique() has_complete_rates = df_stats[['alpha', 'beta']].notna().all().all() except Exception as exc: log_verbose(f"[gravity] failed to inspect existing stage2 CSV ({exc}); retraining.", level=1) existing_cells = -1 has_complete_rates = False if existing_cells == len(dataset) and has_complete_rates: skip_stage2 = True log_verbose("[gravity] stage2 outputs match current dataset; reusing and skipping training.", level=1) else: log_verbose( f"[gravity] existing stage2 outputs mismatch dataset (cells: {existing_cells} vs {len(dataset)}; complete rates: {has_complete_rates}); retraining.", level=1, ) raw_counts_path = resolve_path(config.raw_counts) raw_df = pd.read_csv(raw_counts_path) if 'cellIndex' not in raw_df.columns: raw_df.insert(0, 'cellIndex', pd.factorize(raw_df['cellID'])[0]) raw_df['cellIndex'] = raw_df['cellIndex'].astype(int) if 'clusters' in raw_df.columns: clusters_series = raw_df['clusters'] else: clusters_series = pd.Series(['NA'] * len(raw_df), index=raw_df.index) embedding_map = dict(zip(raw_df['cellIndex'], clusters_series)) id_map = dict(zip(raw_df['cellIndex'], raw_df['cellID'])) if 'cellID' in raw_df.columns else {idx: str(idx) for idx in raw_df['cellIndex']} hvgs = dataset.hvg genes_path = output_dir / 'genes.txt' with genes_path.open('w') as fp: for gene in hvgs: fp.write(f"{gene}\n") mapper_path = output_dir / 'genemap.json' with mapper_path.open('w') as fp: json.dump(dataset.niche_dict, fp, indent=2) if skip_stage2: return { 'stage2_csv': stage2_csv_path, 'checkpoint': checkpoint_path, 'genes_path': genes_path, 'gene_map': mapper_path, } train_subset, val_subset = _split(dataset, config.val_fraction, config.seed) if val_subset is None: log_verbose("[gravity] stage2 training uses all cells; no validation split.", level=1) else: log_verbose( f"[gravity] stage2 split → train: {len(train_subset)} cells, val: {len(val_subset)} cells", level=1, ) train_loader = DataLoader(train_subset if val_subset is not None else dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers) val_loader = None if val_subset is not None: val_loader = DataLoader(val_subset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers) stage2_csv_path = output_dir / config.stage2_csv checkpoint_path = output_dir / config.checkpoint_name stage1_checkpoint_path = resolve_path(config.stage1_checkpoint) model_kwargs = dict( input_dim=len(hvgs) * 2 + 3, gene_list=hvgs, TF_list=dataset.TF_from_net, TG_list=dataset.TG_from_net, TFTG_map=dataset.niche_dict, TFTG_map_reverse=dataset.niche_dict_reverse, output_dim_trans=len(hvgs), embedding_map=embedding_map, id_map=id_map, gene_select=hvgs, nbrs=dataset.nbrs, origin_data=dataset.data, negs=None, output_csv=str(stage2_csv_path), embedding_size=config.embedding_size, model_dimension=config.model_dimension, ffn_dimension=config.ffn_dimension, ) devices = config.devices if config.devices is not None else 1 trainer_kwargs = dict( accelerator=config.accelerator, devices=devices, max_epochs=config.epochs, logger=False, enable_checkpointing=False, log_every_n_steps=config.log_every_n_steps, enable_progress_bar=config.progress_bar, default_root_dir=str(output_dir), ) if config.precision is not None: trainer_kwargs['precision'] = config.precision if config.gradient_clip_val is not None: trainer_kwargs['gradient_clip_val'] = config.gradient_clip_val if config.strategy is not None: trainer_kwargs['strategy'] = config.strategy def _export(model: FullModelGeneWise, description: str, *, save_checkpoint: bool = True) -> None: single_test_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers) log_verbose( f"[gravity] {description}: refining velocities & writing outputs...", level=1, ) export_callbacks = [] if config.progress_bar and TQDMProgressBar is not None: export_callbacks.append(ExportProgressBar("Stage2 infer/export")) single_tester = pl.Trainer( accelerator='auto', devices=1, logger=False, enable_checkpointing=False, enable_progress_bar=config.progress_bar, default_root_dir=str(output_dir), callbacks=export_callbacks, ) single_tester.test(model, dataloaders=single_test_loader) if save_checkpoint: single_tester.save_checkpoint(str(checkpoint_path)) log_verbose("[gravity] stage2 infer/export finished; synchronising files...", level=1) for _ in range(60): if stage2_csv_path.exists(): break time.sleep(1.0) if not stage2_csv_path.exists(): raise RuntimeError(f"Stage2 CSV was not written: {stage2_csv_path}") if config.pretrained_checkpoint is not None: pretrained_path = resolve_path(config.pretrained_checkpoint) log_verbose(f"[gravity] loading stage2 checkpoint for inference/export: {pretrained_path}", level=1) model = FullModelGeneWise.load_from_checkpoint(str(pretrained_path), **model_kwargs) _export(model, "stage2 checkpoint infer/export") return { 'stage2_csv': stage2_csv_path, 'checkpoint': checkpoint_path, 'genes_path': genes_path, 'gene_map': mapper_path, } model = FullModelGeneWise.load_from_checkpoint(str(stage1_checkpoint_path), **model_kwargs) for param in model.parameters(): param.requires_grad = False linear_layers = getattr(model.GravityModel.solver, "linear_layer_names", []) if not linear_layers: trainable_layers = [] else: k = min(3, len(linear_layers)) trainable_layers = linear_layers[-k:] trainable_count = 0 for name, param in model.GravityModel.solver.named_parameters(): if any(name.startswith(layer_name) for layer_name in trainable_layers): param.requires_grad = True trainable_count += 1 if trainable_count == 0: raise RuntimeError( "No solver parameters left trainable; check linear_layer_names or configuration." ) trainer = pl.Trainer(**trainer_kwargs) trainer.fit(model, train_loader, val_loader) if getattr(trainer, "global_rank", 0) == 0: single_model = FullModelGeneWise(**model_kwargs) single_model.load_state_dict(model.state_dict()) _export(single_model, "stage2 infer/export", save_checkpoint=False) trainer.save_checkpoint(str(checkpoint_path)) world_size = 1 if hasattr(trainer, "strategy") and hasattr(trainer.strategy, "world_size"): world_size = trainer.strategy.world_size elif hasattr(trainer, "training_type_plugin") and hasattr(trainer.training_type_plugin, "world_size"): world_size = trainer.training_type_plugin.world_size if world_size and world_size > 1: barrier_fn = None if hasattr(trainer, "strategy") and hasattr(trainer.strategy, "barrier"): barrier_fn = trainer.strategy.barrier elif hasattr(trainer, "training_type_plugin") and hasattr(trainer.training_type_plugin, "barrier"): barrier_fn = trainer.training_type_plugin.barrier if barrier_fn is not None: barrier_fn() return { 'stage2_csv': stage2_csv_path, 'checkpoint': checkpoint_path, 'genes_path': genes_path, 'gene_map': mapper_path, }