gravity Package

User-friendly entry points for the GRAVITY library.

class gravity.PipelineConfig(raw_counts, workdir='gravity_outputs', prior_network='./prior_data/nichenet_mouse.zip', gene_subset=None, gene_order_path=None, batch_size=16, n_pos_neighbors=30, n_neg_neighbors=10, stage1_epochs=6, stage2_epochs=6, stage1_lr=1e-06, stage2_lr=0.0001, stage1_pretrained_checkpoint=None, stage2_pretrained_checkpoint=None, embedding_size=16, model_dimension=16, ffn_dimension=16, val_fraction_stage1=0.0, val_fraction_stage2=0.0, accelerator='auto', devices=None, num_workers=8, strategy=None, precision=None, gradient_clip_val=None, future_tau=0.5, log_every_n_steps=50, progress_bar=True, make_plot=False, plot_gene=None, plot_color='clusters', plot_genes=None, arrow_grid=(20, 20), arrow_scale=1.0, middle_csv_name='combine.csv', stage1_csv_name='stage1.csv', stage2_csv_name='stage2.csv', future_positions_name='future_positions.npy', stage1_checkpoint_name='stage1.ckpt', stage2_checkpoint_name='stage2.ckpt')

Bases: object

Configuration for the full GRAVITY pipeline.

Parameters:
  • raw_counts (str) – Path to the input long-format CSV (must include at least cellID, gene_name, unsplice, splice, embedding1, embedding2).

  • workdir (str) – Output directory where intermediate and final artifacts are written.

  • prior_network (str | None) – Path to the prior TF–target network archive used by GRAVITY.

  • gene_subset (Sequence[str] | None) – Optional list of genes to restrict training and evaluation.

  • gene_order_path (str | None) – Optional newline-delimited gene list that fixes gene-index order. Use this when running pretrained/reference checkpoints; model weights and attention matrices are aligned by gene position, not only by gene name.

  • batch_size (int) – Mini-batch size used by both training stages.

  • stage1_epochs (int) – Number of training epochs for stage 1 (cell-wise).

  • stage2_epochs (int) – Number of training epochs for stage 2 (gene-wise).

  • stage1_lr (float) – Learning rate for stage 1, the unsupervised cell-wise stage. Both stages are somewhat sensitive to the learning rate; for reference-style runs we recommend keeping stage1_lr below 1e-5.

  • stage2_lr (float) – Learning rate for stage 2, the contrastive gene-wise stage. For reference-style runs we recommend tuning stage2_lr between 1e-5 and 1e-3.

  • stage1_pretrained_checkpoint (str | None) – Optional stage-1 checkpoint. When set, it is used for inference/export instead of training the cell-wise stage. Use it for checkpoint-based reference reproduction.

  • stage2_pretrained_checkpoint (str | None) – Optional stage-2 checkpoint. When set, it is used for inference/export instead of training the gene-wise stage. Use it for checkpoint-based reference reproduction.

  • val_fraction_stage1 (float) – Fraction of data reserved for validation in stage 1 (default 0.0, no validation split).

  • val_fraction_stage2 (float) – Fraction of data reserved for validation in stage 2 (default 0.0, no validation split).

  • accelerator (str) – Accelerator forwarded to the PyTorch Lightning trainers of both stages (default 'auto').

  • devices (int | Sequence[int] | None) – Number of devices, or a list of device indices, forwarded to the PyTorch Lightning trainers.

  • strategy (str | None) – Distribution strategy forwarded to the PyTorch Lightning trainers (e.g. 'ddp').

  • precision (int | str | None) – Numerical precision forwarded to the PyTorch Lightning trainers.

  • gradient_clip_val (float | None) – Gradient-clipping threshold forwarded to the PyTorch Lightning trainers; None disables clipping.

  • num_workers (int) – Number of data-loading worker processes used by the stage trainers.

  • future_tau (float) – Scaling factor governing the radius used in future-neighbor search.

  • log_every_n_steps (int) – How often, in training steps, metrics are logged.

  • progress_bar (bool) – Whether to display the training progress bar.

  • make_plot (bool) – Whether to produce the optional velocity plots after the pipeline runs.

  • plot_gene (str | None) – Single gene to plot.

  • plot_color (str | None) – Column used to color cells in the plot (default 'clusters').

  • plot_genes (str | Sequence[str] | None) – One or more genes to plot.

  • arrow_grid (Tuple[int, int]) – Grid resolution used to place velocity arrows (default (20, 20)).

  • arrow_scale (float) – Scaling factor applied to velocity arrow length.

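  • middle_csv_name (str) – Filename, under workdir, of the intermediate cell-wise table produced by preprocessing.

  • stage1_csv_name (str) – Filename, under workdir, of the stage-1 output CSV.

  • stage2_csv_name (str) – Filename, under workdir, of the stage-2 output CSV.

  • future_positions_name (str) – Filename, under workdir, of the future-position array (.npy).

  • stage1_checkpoint_name (str) – Filename, under workdir, of the stage-1 checkpoint.

  • stage2_checkpoint_name (str) – Filename, under workdir, of the stage-2 checkpoint.

  • n_pos_neighbors (int) – Number of positive neighbors per cell used by the cell-wise stage.

  • n_neg_neighbors (int) – Number of negative neighbors per cell used by the cell-wise stage.

  • embedding_size (int) – Embedding dimension of the model.

  • model_dimension (int) – Hidden dimension of the model.

  • ffn_dimension (int) – Width of the feed-forward layers.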

raw_counts: str
workdir: str = 'gravity_outputs'
prior_network: str | None = './prior_data/nichenet_mouse.zip'
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 16
n_pos_neighbors: int = 30
n_neg_neighbors: int = 10
stage1_epochs: int = 6
stage2_epochs: int = 6
stage1_lr: float = 1e-06
stage2_lr: float = 0.0001
stage1_pretrained_checkpoint: str | None = None
stage2_pretrained_checkpoint: str | None = None
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16
val_fraction_stage1: float = 0.0
val_fraction_stage2: float = 0.0
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 8
strategy: str | None = None
precision: int | str | None = None
gradient_clip_val: float | None = None
future_tau: float = 0.5
log_every_n_steps: int = 50
progress_bar: bool = True
make_plot: bool = False
plot_gene: str | None = None
plot_color: str | None = 'clusters'
plot_genes: str | Sequence[str] | None = None
arrow_grid: Tuple[int, int] = (20, 20)
arrow_scale: float = 1.0
middle_csv_name: str = 'combine.csv'
stage1_csv_name: str = 'stage1.csv'
stage2_csv_name: str = 'stage2.csv'
future_positions_name: str = 'future_positions.npy'
stage1_checkpoint_name: str = 'stage1.ckpt'
stage2_checkpoint_name: str = 'stage2.ckpt'
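
The learning-rate guidance for stage1_lr and stage2_lr can be encoded as a quick pre-flight check before a long run. This is a minimal sketch: check_learning_rates is a hypothetical helper, not part of the gravity package, and it only restates the recommended ranges given in the parameter descriptions.

```python
from typing import List


def check_learning_rates(stage1_lr: float, stage2_lr: float) -> List[str]:
    """Return human-readable issues; an empty list means both rates fit the guidance.

    Hypothetical helper restating the recommendation: stage1_lr below 1e-5,
    stage2_lr between 1e-5 and 1e-3 (inclusive).
    """
    issues = []
    if not stage1_lr < 1e-5:
        issues.append(f"stage1_lr={stage1_lr:g} is not below 1e-5")
    if not 1e-5 <= stage2_lr <= 1e-3:
        issues.append(f"stage2_lr={stage2_lr:g} is outside [1e-5, 1e-3]")
    return issues


# The PipelineConfig defaults (1e-06 and 0.0001) satisfy both ranges.
print(check_learning_rates(1e-06, 0.0001))  # → []
print(check_learning_rates(1e-04, 0.01))
```

The check only warns; both stages will still train with out-of-range values.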

gravity.run_pipeline(config)

Execute preprocessing, two training stages, future projection, and optional plotting.

Parameters:

config (PipelineConfig) – PipelineConfig instance that specifies inputs, training and plotting options, and output filenames.

Returns:

Dict[str, pathlib.Path] – Mapping of artifact names to absolute paths under workdir. Keys include middle_csv, stage1_csv, stage1_checkpoint, attention_dir, future_positions, stage2_csv, and stage2_checkpoint. If plotting is enabled, additional keys may be present for generated figures.

Return type:

Dict[str, Path]

Notes

Multi-GPU controls in config are forwarded to PyTorch Lightning inside the stage trainers. Preprocessing, future projection and plotting always run in the main process. Under DDP (e.g. strategy='ddp' or 'ddp_spawn'), child processes may re-import and re-execute the entry script; to avoid duplicated orchestration, this function returns early on non-zero ranks.
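
The early return on non-zero ranks can be pictured as a rank guard around the orchestration code. The sketch assumes the process rank is exposed through the RANK or LOCAL_RANK environment variables (torchrun sets RANK; Lightning's subprocess-based DDP launcher typically sets LOCAL_RANK for the processes it starts). current_rank and orchestrate are hypothetical names, and GRAVITY's own rank detection may differ.

```python
import os
from typing import Callable, Dict, Optional


def current_rank() -> int:
    """Best-effort process rank from the environment; 0 when not distributed.

    Assumption: the launcher exports RANK or LOCAL_RANK. LOCAL_RANK equals
    the global rank only on a single node.
    """
    for var in ("RANK", "LOCAL_RANK"):
        value = os.environ.get(var)
        if value is not None:
            return int(value)
    return 0


def orchestrate(run_stages: Callable[[], Dict[str, str]]) -> Optional[Dict[str, str]]:
    """Hypothetical guard: only rank 0 runs the orchestration callable."""
    if current_rank() != 0:
        return None  # non-zero ranks return early instead of re-running the pipeline
    return run_stages()
```

In a single-process run no rank variable is set, so current_rank() returns 0 and the callable runs normally.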

gravity.preprocess_counts(input_file, output_csv, *, gene_order=None)

Prepare the cell-wise training table from a long-format single-cell CSV.

Parameters:
  • input_file (str) – Path to the raw long-format CSV with columns including cellID, gene_name, unsplice, splice, embedding1, embedding2.

  • output_csv (str) – Destination CSV containing one row per cell with serialized gene tuples.

  • gene_order (Sequence[str] | None) – Optional ordered gene list. When provided, matching gene columns are written first in this order so downstream checkpoints see the intended gene-index layout.

Returns:

pathlib.Path – The path to the generated intermediate CSV.

Return type:

Path
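
To make the one-row-per-cell layout and the effect of gene_order concrete, the sketch groups long-format rows by cellID and places genes listed in gene_order first. It is an illustration under assumptions: pivot_long_counts is hypothetical, each gene is stored as an (unsplice, splice) tuple, and GRAVITY's actual serialization of gene tuples is not documented here.

```python
import csv
import io
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


def pivot_long_counts(
    rows: Iterable[Dict[str, str]],
    gene_order: Optional[Sequence[str]] = None,
) -> Tuple[List[str], Dict[str, Dict[str, Tuple[float, float]]]]:
    """Group long-format rows into one record per cell (illustrative only).

    Each gene maps to an (unsplice, splice) tuple. Genes in gene_order that are
    present in the data come first, in that order; the rest follow in
    first-seen order. Embedding columns are ignored here.
    """
    cells: Dict[str, Dict[str, Tuple[float, float]]] = {}
    seen: Dict[str, None] = {}  # insertion-ordered set of gene names
    for row in rows:
        gene = row["gene_name"]
        seen.setdefault(gene, None)
        cells.setdefault(row["cellID"], {})[gene] = (float(row["unsplice"]), float(row["splice"]))
    ordered = [g for g in (gene_order or []) if g in seen]
    placed = set(ordered)
    ordered += [g for g in seen if g not in placed]
    return ordered, cells


raw = io.StringIO(
    "cellID,gene_name,unsplice,splice,embedding1,embedding2\n"
    "c1,Sox2,0.5,1.0,0.1,0.2\n"
    "c1,Pax6,0.0,2.0,0.1,0.2\n"
    "c2,Sox2,1.5,0.5,0.3,0.4\n"
)
genes, cells = pivot_long_counts(csv.DictReader(raw), gene_order=["Pax6", "Neurod1"])
print(genes)        # → ['Pax6', 'Sox2']
print(cells["c2"])  # → {'Sox2': (1.5, 0.5)}
```

Genes named in gene_order but absent from the data (Neurod1 here) are skipped rather than added as empty columns; that choice is also an assumption.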

gravity.load_gene_order(path)

Load a newline-delimited gene order file.

The order is part of the model contract for pretrained checkpoints because attention and solver tensors are indexed by gene position.

Parameters:

path (str) – Path to a newline-delimited text file with one gene name per line.

Return type:

list[str]
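
Because checkpoints index genes by position, the same gene names in a different order map to different tensor slots. A minimal sketch, using hypothetical helpers (read_gene_order, align_to_order) rather than gravity.load_gene_order itself, reads a newline-delimited file and lays per-gene values out in that order. Skipping blank lines is an assumption; the source does not say how gravity handles them.

```python
import os
import tempfile
from typing import Dict, List, Optional


def read_gene_order(path: str) -> List[str]:
    """Illustrative reader: one gene per line, whitespace stripped, blank lines dropped."""
    with open(path, encoding="utf-8") as handle:
        return [line.strip() for line in handle if line.strip()]


def align_to_order(values_by_gene: Dict[str, float], gene_order: List[str]) -> List[Optional[float]]:
    """Lay values out by gene position; genes absent from the data become None."""
    return [values_by_gene.get(gene) for gene in gene_order]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gene_order.txt")
    with open(path, "w", encoding="utf-8") as handle:
        handle.write("Pax6\nSox2\n\nNeurod1\n")
    order = read_gene_order(path)

print(order)                                              # → ['Pax6', 'Sox2', 'Neurod1']
print(align_to_order({"Sox2": 1.0, "Pax6": 2.0}, order))  # → [2.0, 1.0, None]
```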

gravity.resolve_gene_order(gene_subset=None, gene_order_path=None)

Resolve the effective gene list, preserving checkpoint-compatible order.

Parameters:
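  • gene_subset (Sequence[str] | None) – Optional list of genes to restrict training and evaluation.

  • gene_order_path (str | None) – Optional newline-delimited gene order file (see load_gene_order()).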

Return type:

list[str] | None
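
The exact resolution rules are not spelled out in the source. One plausible rule consistent with preserving checkpoint-compatible order is sketched here: the file order wins, optionally restricted to gene_subset. resolve_gene_order_sketch is hypothetical and takes the already-loaded file order instead of a path.

```python
from typing import List, Optional, Sequence


def resolve_gene_order_sketch(
    gene_subset: Optional[Sequence[str]],
    file_order: Optional[List[str]],
) -> Optional[List[str]]:
    """One plausible resolution rule (assumption, not documented GRAVITY behavior).

    - neither given: None (use all genes in the data's own order);
    - subset only: the subset with duplicates removed, order kept;
    - file order only: the file order as-is;
    - both: the file order, restricted to genes in the subset.
    """
    if file_order is None:
        return None if gene_subset is None else list(dict.fromkeys(gene_subset))
    if gene_subset is None:
        return list(file_order)
    wanted = set(gene_subset)
    return [gene for gene in file_order if gene in wanted]


print(resolve_gene_order_sketch(["Sox2", "Pax6"], ["Pax6", "Neurod1", "Sox2"]))  # → ['Pax6', 'Sox2']
```

Letting the file order win over the subset's own order keeps gene positions aligned with a pretrained checkpoint even when the subset is listed differently.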

gravity.export_intermediate_from_h5ad(input_h5ad, output_csv, *, retain_genes=None, min_shared_counts=20, n_top_genes=2000, n_pcs=30, n_neighbors=30, embed_key='X_umap', celltype_key='cell_type', overwrite=False, normalized=False)

Preprocess AnnData for GRAVITY and export a cellDancer-style CSV.

This function performs only the preprocessing required to create GRAVITY’s user-facing long-format count table. It does not run RNA velocity inference, does not plot, and does not save a processed .h5ad.

Pipeline

  1. Read the input AnnData (.h5ad).

  2. Optionally force-keep user-specified genes through filtering/HVG selection.

  3. Normalize and filter genes with the local preprocessing settings.

  4. Compute first-order moments with scv.pp.moments (produces ‘Mu’/‘Ms’).

  5. Export a cellDancer-style CSV via the local adata_to_df_with_embed function, including Mu/Ms, the chosen 2D embedding, and cell-type labels.
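
Step 4 smooths spliced and unspliced counts over each cell's neighborhood. The numpy sketch shows the idea with a brute-force k-nearest-neighbor average (the cell itself included) in a low-dimensional space. It is a simplification: scvelo builds its neighbor graph with scv.pp.neighbors and computes moments from the resulting connectivities, and first_order_moments is a hypothetical name.

```python
import numpy as np


def first_order_moments(X, coords, n_neighbors=2):
    """Mean of each cell's values over its n_neighbors nearest cells, itself included.

    Simplified illustration of a first-order moment: brute-force distances in
    `coords` (e.g. a few principal components) instead of scvelo's neighbor graph.
    """
    X = np.asarray(X, dtype=float)
    coords = np.asarray(coords, dtype=float)
    # Pairwise squared Euclidean distances between cells.
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    # Closest n_neighbors cells per row; a cell is at distance 0 from itself.
    idx = np.argsort(d2, axis=1)[:, :n_neighbors]
    # X[idx] has shape (cells, n_neighbors, genes); average over the neighbors.
    return X[idx].mean(axis=1)


spliced = np.array([[1.0, 0.0], [3.0, 2.0], [10.0, 10.0]])  # 3 cells x 2 genes
pcs = np.array([[0.0], [0.1], [5.0]])                        # 1-D stand-in for PCA space
Ms = first_order_moments(spliced, pcs, n_neighbors=2)
print(Ms)
```

Applying the same function to the unspliced layer gives the Mu counterpart.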

Parameters:
  • input_h5ad (str) – Path to an AnnData file that includes ‘spliced’ and ‘unspliced’ layers.

  • output_csv (str) – Destination CSV file path for the cellDancer-style long-format table.

  • retain_genes (Sequence[str] | None) – Genes that must be retained during filtering/HVG selection (if present).

  • min_shared_counts (int) – Minimum shared counts across cells for gene filtering.

  • n_top_genes (int) – Number of highly variable genes to retain (in addition to retain_genes).

  • n_pcs (int) – Number of principal components used by scv.pp.moments.

  • n_neighbors (int) – Neighborhood size used by scv.pp.moments.

  • embed_key (str) – Key in adata.obsm for a 2D embedding (e.g., “X_umap”).

  • celltype_key (str) – Column in adata.obs that holds cell-type labels (e.g., “cell_type” or “celltype”).

  • overwrite (bool) – If False and output_csv exists, skip work and return the existing path.

  • normalized (bool)

Returns:

pathlib.Path – The path to the generated CSV.

Return type:

Path

Raises:
  • KeyError – If embed_key is not found in adata.obsm.

  • RuntimeError – If the required layers (‘spliced’ and ‘unspliced’) are missing.

  • ImportError – If scanpy or scvelo is not installed.

class gravity.CellStageConfig(raw_counts, middle_csv, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage1_csv='stage1.csv', checkpoint_name='stage1.ckpt', pretrained_checkpoint=None, attention_dir='attentions', gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, attention_topk=64, attention_output=True, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=1e-06, embedding_size=16, model_dimension=16, ffn_dimension=16, n_pos_neighbors=30, n_neg_neighbors=10)

Bases: object

Configuration bundle for the cell-wise training stage.

See also the top-level gravity.pipeline.PipelineConfig for how device/distribution options propagate. gene_order_path fixes the checkpoint-compatible gene index order when using pretrained/reference weights.

Parameters:
  • raw_counts (str)

  • middle_csv (str)

  • prior_network (str | None)

  • output_dir (str)

  • stage1_csv (str)

  • checkpoint_name (str)

  • pretrained_checkpoint (str | None)

  • attention_dir (str)

  • gene_subset (Sequence[str] | None)

  • gene_order_path (str | None)

  • batch_size (int)

  • epochs (int)

  • accelerator (str)

  • devices (int | Sequence[int] | None)

  • num_workers (int)

  • val_fraction (float)

  • attention_topk (int)

  • attention_output (bool)

  • precision (int | str | None)

  • gradient_clip_val (float | None)

  • strategy (str | None)

  • seed (int)

  • log_every_n_steps (int)

  • progress_bar (bool)

  • learning_rate (float)

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

  • n_pos_neighbors (int)

  • n_neg_neighbors (int)

raw_counts: str
middle_csv: str
prior_network: str | None = './prior_data/nichenet_mouse.zip'
output_dir: str = 'gravity_outputs'
stage1_csv: str = 'stage1.csv'
checkpoint_name: str = 'stage1.ckpt'
pretrained_checkpoint: str | None = None
attention_dir: str = 'attentions'
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 32
epochs: int = 6
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 0
val_fraction: float = 0.0
attention_topk: int = 64
attention_output: bool = True
precision: int | str | None = None
gradient_clip_val: float | None = None
strategy: str | None = None
seed: int = 42
log_every_n_steps: int = 50
progress_bar: bool = True
learning_rate: float = 1e-06
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16
n_pos_neighbors: int = 30
n_neg_neighbors: int = 10
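
The source does not document how n_pos_neighbors and n_neg_neighbors are used. As an assumption only, a common contrastive scheme treats each cell's nearest neighbors in the embedding as positives and draws negatives at random from the remaining cells. The sketch implements that scheme with numpy; sample_pos_neg is hypothetical, and its output shapes merely mirror the two config fields (the model's nbrs and negs arguments may be built differently).

```python
import numpy as np


def sample_pos_neg(embedding, n_pos, n_neg, seed=42):
    """Hypothetical contrastive sampling over cells in a 2D embedding.

    Positives: the n_pos nearest other cells. Negatives: n_neg cells drawn
    without replacement from everything outside the positive set.
    """
    rng = np.random.default_rng(seed)
    emb = np.asarray(embedding, dtype=float)
    n_cells = emb.shape[0]
    if n_cells < n_pos + n_neg + 1:
        raise ValueError("not enough cells for the requested positives and negatives")
    d2 = ((emb[:, None, :] - emb[None, :, :]) ** 2).sum(axis=-1)
    # Column 0 is the cell itself (distance 0), assuming no duplicate coordinates.
    order = np.argsort(d2, axis=1)
    pos = order[:, 1:n_pos + 1]
    neg = np.empty((n_cells, n_neg), dtype=int)
    for i in range(n_cells):
        candidates = order[i, n_pos + 1:]  # excludes the cell and its positives
        neg[i] = rng.choice(candidates, size=n_neg, replace=False)
    return pos, neg


cells = np.random.default_rng(0).normal(size=(40, 2))
pos, neg = sample_pos_neg(cells, n_pos=5, n_neg=3)
print(pos.shape, neg.shape)  # → (40, 5) (40, 3)
```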

gravity.train_cell_stage(config)

Run the GRAVITY cell-wise stage and export key artifacts.

Returns:

Dict[str, Path] – Paths to stage1_csv, checkpoint, attention outputs (optional), and the generated gene files used by later steps.

Parameters:

config (CellStageConfig)

Return type:

Dict[str, Path]

class gravity.GeneStageConfig(raw_counts, middle_csv, stage1_checkpoint, future_positions, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage2_csv='stage2.csv', checkpoint_name='stage2.ckpt', pretrained_checkpoint=None, gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=0.0001, embedding_size=16, model_dimension=16, ffn_dimension=16)

Bases: object

Configuration bundle for the gene-wise fine-tuning stage.

Device and distribution arguments mirror those in gravity.pipeline.PipelineConfig. gene_order_path should point to the same gene list used by the stage-1 checkpoint when reproducing pretrained/reference runs.

Parameters:
  • raw_counts (str)

  • middle_csv (str)

  • stage1_checkpoint (str)

  • future_positions (str)

  • prior_network (str | None)

  • output_dir (str)

  • stage2_csv (str)

  • checkpoint_name (str)

  • pretrained_checkpoint (str | None)

  • gene_subset (Sequence[str] | None)

  • gene_order_path (str | None)

  • batch_size (int)

  • epochs (int)

  • accelerator (str)

  • devices (int | Sequence[int] | None)

  • num_workers (int)

  • val_fraction (float)

  • precision (int | str | None)

  • gradient_clip_val (float | None)

  • strategy (str | None)

  • seed (int)

  • log_every_n_steps (int)

  • progress_bar (bool)

  • learning_rate (float)

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

raw_counts: str
middle_csv: str
stage1_checkpoint: str
future_positions: str
prior_network: str | None = './prior_data/nichenet_mouse.zip'
output_dir: str = 'gravity_outputs'
stage2_csv: str = 'stage2.csv'
checkpoint_name: str = 'stage2.ckpt'
pretrained_checkpoint: str | None = None
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 32
epochs: int = 6
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 0
val_fraction: float = 0.0
precision: int | str | None = None
gradient_clip_val: float | None = None
strategy: str | None = None
seed: int = 42
log_every_n_steps: int = 50
progress_bar: bool = True
learning_rate: float = 0.0001
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16

gravity.train_gene_stage(config)

Run the GRAVITY gene-wise refinement stage.

Returns:

Dict[str, Path] – Paths to stage2_csv, checkpoint and the gene/prior files.

Parameters:

config (GeneStageConfig)

Return type:

Dict[str, Path]

class gravity.FullModelCellWise(input_dim, TF_list, TG_list, TFTG_map, TFTG_map_reverse, embedding_size=16, model_dimension=16, ffn_dimension=16, output_dim_trans=10, output_dim_dense=20, gene_list=None, embedding_map=None, id_map=None, gene_select=None, nbrs=None, origin_data=None, negs=None, output_csv=None, attention_output=True, output_network_path='attentions', cell_to_type=None, gene_list_path='genes.txt', gene_mapper_path='genemap.json', csr_topk=64, learning_rate=1e-06)

Bases: LightningModule

PyTorch Lightning module encapsulating the cell-wise GRAVITY stage.

Parameters:
  • input_dim (int)

  • TF_list (Sequence[str])

  • TG_list (Sequence[str])

  • TFTG_map (Dict[str, Sequence[str]])

  • TFTG_map_reverse (Dict[str, Sequence[str]])

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

  • output_dim_trans (int)

  • output_dim_dense (int)

  • gene_list (Optional[Sequence[str]])

  • embedding_map (Optional[Dict[int, str]])

  • id_map (Optional[Dict[int, str]])

  • gene_select (Optional[Sequence[str]])

  • nbrs (Optional[np.ndarray])

  • origin_data (Optional[pd.DataFrame])

  • negs (Optional[np.ndarray])

  • output_csv (Optional[str])

  • attention_output (bool)

  • output_network_path (str)

  • cell_to_type (Optional[Dict[str, str]])

  • gene_list_path (str)

  • gene_mapper_path (str)

  • csr_topk (int)

  • learning_rate (float)
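
The names csr_topk (and attention_topk in CellStageConfig, default 64) suggest that only the largest attention weights per row are kept and stored in a sparse CSR layout. That reading is an assumption; the sketch shows one way to do it with numpy, returning the three CSR arrays (indptr, indices, data). topk_to_csr is a hypothetical helper.

```python
import numpy as np


def topk_to_csr(attn, k):
    """Keep the k largest entries of each row and return CSR arrays (indptr, indices, data).

    Hypothetical reading of csr_topk: the docs give only the parameter name
    and its default, not the storage format.
    """
    attn = np.asarray(attn, dtype=float)
    n_rows, n_cols = attn.shape
    k = min(k, n_cols)
    # argpartition places the k largest values (smallest of -attn) in the first k slots.
    top = np.argpartition(-attn, k - 1, axis=1)[:, :k]
    top.sort(axis=1)  # CSR convention: ascending column indices within a row
    data = np.take_along_axis(attn, top, axis=1)
    # Every row stores exactly k entries, so row pointers advance in steps of k.
    indptr = np.arange(0, (n_rows + 1) * k, k)
    return indptr, top.ravel(), data.ravel()


attn = np.array([[0.1, 0.7, 0.2], [0.5, 0.1, 0.4]])
indptr, indices, data = topk_to_csr(attn, k=2)
print(indptr, indices, data)
```

The three arrays can be handed to scipy.sparse.csr_matrix((data, indices, indptr), shape=attn.shape) when SciPy is available.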

forward(x, mask=None)

Forward pass of the cell-wise model (overrides torch.nn.Module.forward()). Takes an input tensor x and an optional mask tensor and returns the model output.

Parameters:
  • x (Tensor)

  • mask (Tensor | None)

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Args:
  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader: a tensor, tuple or list.

  • batch_idx (int) – Index of this batch.

Return:

Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

test_step(batch, batch_idx)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Args:
  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch (only if multiple test dataloaders are used).

Return:

Any of:

  • Any object or value

  • None - Testing will skip to the next batch

# if you have one test dataloader:
def test_step(self, batch, batch_idx):
    ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0):
    ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note:

If you don’t need to test, you don’t need to implement this method.

Note:

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

Return type:

None

configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 5 options:

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

A metric can be made available for monitoring by logging it with self.log('metric_to_track', metric_val) in your LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

class gravity.FullModelGeneWise(input_dim, TF_list, TG_list, TFTG_map, TFTG_map_reverse, embedding_size=16, model_dimension=16, ffn_dimension=16, output_dim_trans=10, output_dim_dense=20, gene_list=None, embedding_map=None, id_map=None, gene_select=None, nbrs=None, origin_data=None, negs=None, output_csv=None, learning_rate=0.0001)

Bases: LightningModule

PyTorch Lightning module encapsulating the gene-wise GRAVITY stage.

Parameters:
  • input_dim (int)

  • TF_list (Sequence[str])

  • TG_list (Sequence[str])

  • TFTG_map (Dict[str, Sequence[str]])

  • TFTG_map_reverse (Dict[str, Sequence[str]])

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

  • output_dim_trans (int)

  • output_dim_dense (int)

  • gene_list (Optional[Sequence[str]])

  • embedding_map (Optional[Dict[int, str]])

  • id_map (Optional[Dict[int, str]])

  • gene_select (Optional[Sequence[str]])

  • nbrs (Optional[np.ndarray])

  • origin_data (Optional[pd.DataFrame])

  • negs (Optional[np.ndarray])

  • output_csv (Optional[str])

  • learning_rate (float)

forward(x, mask=None)

Forward pass of the gene-wise model (overrides torch.nn.Module.forward()). Takes an input tensor x and an optional mask tensor and returns the model output.

training_step(batch, batch_idx)

Training hook for the gene-wise stage. The generic LightningModule contract (arguments, accepted return values, manual optimization) is the same as documented for FullModelCellWise.training_step().

test_step(batch, batch_idx)

Test hook for the gene-wise stage; see FullModelCellWise.test_step() for the generic LightningModule contract.

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

Return type:

None

configure_optimizers()

Returns the optimizer configuration for the gene-wise stage; see FullModelCellWise.configure_optimizers() for the return formats Lightning accepts.

gravity.estimate_future_positions(stage1_csv, output_path, *, tau=0.5, show_plot=False, plot_path=None, projection_neighbor_choice='embedding', projection_neighbor_size=200, expression_scale='power10')[source]

Estimate future embeddings using the learned cell-wise velocities.

Parameters:
  • stage1_csv (str) – Output CSV from the cell-wise stage (stage1).

  • output_path (str) – Destination .npy file to store the nearest-neighbour anchors.

  • tau (float) – Scaling factor that shrinks the velocity vectors when forming the search radius.

  • show_plot (bool) – Whether to display a matplotlib window with the quiver plot.

  • plot_path (str | None) – Optional path to save the figure instead of (or in addition to) showing it.

  • projection_neighbor_choice (str)

  • projection_neighbor_size (int)

  • expression_scale (str | None)

Returns:

final_positions, cell_df – Tuple of the n × 3 array with neighbour coordinates + indices, and the augmented dataframe returned by compute_cell_velocity_().

Return type:

Tuple[ndarray, DataFrame]
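The docstring does not spell out the neighbour search. As a rough mental model only, the sketch below assumes each cell is displaced by tau × velocity in the 2D embedding and then snapped to its nearest existing cell, which yields the n × 3 layout (neighbour coordinates plus neighbour index) described under Returns. The column order, the nearest-neighbour rule (GRAVITY describes a radius-based search) and the function name project_future_positions are all assumptions for illustration.

```python
import numpy as np


def project_future_positions(embedding, velocity, tau=0.5):
    """Hypothetical sketch: snap each displaced cell to its nearest neighbour.

    embedding: (n, 2) array of current 2D coordinates.
    velocity:  (n, 2) array of per-cell velocities in the same space.
    Returns an (n, 3) array: neighbour x, neighbour y, neighbour index.
    """
    embedding = np.asarray(embedding, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    # Shrink each velocity by tau to form the target point for each cell.
    targets = embedding + tau * velocity
    # Brute-force distances from every target to every existing cell.
    dists = np.linalg.norm(targets[:, None, :] - embedding[None, :, :], axis=-1)
    nearest = dists.argmin(axis=1)
    # Columns: neighbour coordinates, then the neighbour's row index.
    return np.column_stack([embedding[nearest], nearest.astype(float)])
```

A brute-force distance matrix keeps the sketch short; for large datasets a KD-tree would be the usual replacement.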

gravity.plot_velocity_gene(stage_csv, *, gene=None, x='splice', y='unsplice', color_by='clusters', palette=None, categories=None, cmap='viridis', point_size=200.0, alpha=0.5, arrow_grid=(20, 20), arrow_scale=1.0, arrow_color='k', projection_neighbor_choice='embedding', expression_scale='power10', projection_neighbor_size=200, min_mass=0.5, axis_off=True, output_path=None, show=False)[source]

Plot gene-level phase portraits with projected velocity arrows.

The arrows connect current to predicted expression for a sampled subset of cells.

Parameters:
  • stage_csv (str)

  • gene (str | None)

  • x (str)

  • y (str)

  • color_by (str | None)

  • palette (Mapping[str, str] | None)

  • categories (Sequence[str] | None)

  • cmap (str)

  • point_size (float)

  • alpha (float)

  • arrow_grid (Tuple[int, int])

  • arrow_scale (float)

  • arrow_color (str)

  • projection_neighbor_choice (str)

  • expression_scale (str | None)

  • projection_neighbor_size (int)

  • min_mass (float)

  • axis_off (bool)

  • output_path (str | None)

  • show (bool)

Return type:

Axes

gravity.plot_velocity_cell(stage_csv, *, gene=None, x='splice', y='unsplice', color_by='clusters', palette=None, categories=None, cmap='viridis', point_size=200.0, alpha=0.5, arrow_grid=(20, 20), arrow_scale=1.0, arrow_color='k', projection_neighbor_choice='embedding', expression_scale='power10', projection_neighbor_size=200, min_mass=0.5, axis_off=True, output_path=None, show=False)[source]

Plot cell-level velocities in the embedding space.

The function computes velocity projections via gravity.velocity.compute_cell_velocity_() and overlays grid-curved arrows. Color handling supports both discrete palettes and continuous scalars.

Parameters:
  • stage_csv (str)

  • gene (str | None)

  • x (str)

  • y (str)

  • color_by (str | None)

  • palette (Mapping[str, str] | None)

  • categories (Sequence[str] | None)

  • cmap (str)

  • point_size (float)

  • alpha (float)

  • arrow_grid (Tuple[int, int])

  • arrow_scale (float)

  • arrow_color (str)

  • projection_neighbor_choice (str)

  • expression_scale (str | None)

  • projection_neighbor_size (int)

  • min_mass (float)

  • axis_off (bool)

  • output_path (str | None)

  • show (bool)

Return type:

Axes
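How plot_velocity_cell places its arrows is not documented beyond the arrow_grid and min_mass parameters. A common approach (for example, scVelo's grid-based velocity plots) averages per-cell velocities inside the cells of a regular grid and drops sparsely populated grid cells. The sketch below assumes arrow_grid sets the grid shape and treats min_mass as a plain cell-count threshold, which is a simplification; the helper name is hypothetical, and GRAVITY may use kernel-weighted smoothing instead.

```python
import numpy as np


def grid_average_velocity(embedding, velocity, arrow_grid=(20, 20), min_mass=0.5):
    """Hypothetical sketch: average per-cell velocities on a regular 2D grid.

    Returns (centers, mean_velocity) for grid cells whose cell count exceeds
    min_mass; min_mass is treated here as a simple count threshold.
    """
    embedding = np.asarray(embedding, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    nx, ny = arrow_grid
    x_edges = np.linspace(embedding[:, 0].min(), embedding[:, 0].max(), nx + 1)
    y_edges = np.linspace(embedding[:, 1].min(), embedding[:, 1].max(), ny + 1)
    # Bin index per cell; clipping puts points on the upper edge into the last bin.
    ix = np.clip(np.digitize(embedding[:, 0], x_edges) - 1, 0, nx - 1)
    iy = np.clip(np.digitize(embedding[:, 1], y_edges) - 1, 0, ny - 1)
    flat = ix * ny + iy
    counts = np.bincount(flat, minlength=nx * ny)
    sums_x = np.bincount(flat, weights=velocity[:, 0], minlength=nx * ny)
    sums_y = np.bincount(flat, weights=velocity[:, 1], minlength=nx * ny)
    keep = counts > min_mass
    mean_v = np.column_stack([sums_x[keep], sums_y[keep]]) / counts[keep][:, None]
    # Centres of the grid cells that survived the mass filter.
    cx = (x_edges[:-1] + x_edges[1:]) / 2
    cy = (y_edges[:-1] + y_edges[1:]) / 2
    kept = np.nonzero(keep)[0]
    centers = np.column_stack([cx[kept // ny], cy[kept % ny]])
    return centers, mean_v
```

The returned centres and mean vectors are what a quiver-style overlay would consume; arrow_scale would then only rescale mean_v at drawing time.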

gravity.rank_tf_scores(attention_h5ad, *, groupby='cell_type', method='wilcoxon', key_added='tf_rankings', n_genes=30, output_plot=None, sort_group=None, top_n=30, reuse_h5ad=None)[source]

Run differential ranking on TF scores aggregated per cell type.

Parameters:
  • attention_h5ad (str) – Path to attention_TF_scores_with_types.h5ad.

  • sort_group (int | str | None) – Group to sort by (logFC descending, p-value ascending). Accepts index (int) or name (str).

  • reuse_h5ad (str | None) – If provided with an h5ad already containing uns[key_added], reuse the stored ranking results.

  • groupby (str)

  • method (str)

  • key_added (str)

  • n_genes (int)

  • output_plot (str | None)

  • top_n (int)

Returns:

all_rankings, top_rankings – DataFrame of all per-group rankings, and (if sort_group is given) the Top-N table for the selected group.

Return type:

Tuple[DataFrame, DataFrame | None]
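The sort_group ordering (logFC descending, then p-value ascending) can be illustrated on plain records. The field names (group, tf, logfoldchange, pval) and the helper top_tfs are hypothetical; the real function works on the rankings stored under uns[key_added].

```python
from typing import Dict, List, Optional


def top_tfs(rankings: List[Dict], group: str, top_n: Optional[int] = 30) -> List[Dict]:
    """Hypothetical sketch: rank one group's TFs by logFC (desc), then p-value (asc)."""
    rows = [r for r in rankings if r["group"] == group]
    # Negating logFC lets a single ascending sort give logFC descending;
    # ties on logFC are broken by the smaller p-value.
    rows.sort(key=lambda r: (-r["logfoldchange"], r["pval"]))
    return rows[:top_n]
```

Slicing with top_n=None returns every row for the group, which loosely mirrors the "all rankings" versus "Top-N" split in the return value.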

gravity.compute_batc(adata, cluster_edges, *, cluster_key='clusters', embedding_key='umap', velocity_key='velocity', use_bins=True, n_bins=60, min_per_bin=5, n_samples=800, store_in_adata=True, progress=False)[source]

Compute the Branching-aware Trajectory Consistency (BATC) score.

BATC evaluates RNA velocity fields on predefined lineage graphs by fitting smooth principal curves between successive clusters and comparing their tangents with per-cell velocity vectors. For each directed edge \(A \rightarrow B\), a curve \(\gamma_{A \to B}(u)\) is fitted on the 2D embedding of cells belonging to \(A \cup B\). Every cell \(c\) on the edge is projected onto the curve, and the cosine similarity between the local tangent \(\tau_c\) and its velocity \(v_c\) is recorded:

\[
\mathrm{BATC}_{A \to B} = \frac{1}{|I_{A \cup B}|} \sum_{c \in I_{A \cup B}} \frac{v_c \cdot \tau_c}{\lVert v_c \rVert \, \lVert \tau_c \rVert}.
\]

To handle branching, for each source cluster \(A\) with outgoing targets \(B_1, \ldots, B_m\) we compute these cosine scores for every outgoing edge and, for each cell \(c \in I_A\), keep the best-matching branch:

\[
b_c^{(A)} = \max_{B \in \mathrm{Out}(A)} \frac{v_c \cdot \tau_c^{(A \to B)}}{\lVert v_c \rVert \, \lVert \tau_c^{(A \to B)} \rVert}.
\]

The final dataset-level BATC is the cell-weighted mean of these branch-aware scores over all source clusters:

\[
\mathrm{BATC}_{\mathrm{overall}} = \frac{1}{\sum_A |I_A|} \sum_A \sum_{c \in I_A} b_c^{(A)}.
\]

Zero-norm vectors are mapped to nan scores, and all averages use numpy.nanmean for robustness.
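Once a tangent is available for every source cell on every outgoing edge, the branch-aware aggregation reduces to a few array operations. The sketch below omits principal-curve fitting and projection entirely and assumes the tangents are precomputed and supplied per source cluster; the helpers cosine_rows and batc_overall and the input layout are hypothetical, not part of GRAVITY's API.

```python
import numpy as np


def cosine_rows(v, t):
    """Row-wise cosine similarity; rows with a zero-norm vector give NaN."""
    v = np.asarray(v, dtype=float)
    t = np.asarray(t, dtype=float)
    denom = np.linalg.norm(v, axis=1) * np.linalg.norm(t, axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(denom > 0, (v * t).sum(axis=1) / denom, np.nan)


def batc_overall(velocities, tangents_by_source):
    """Branch-aware BATC from precomputed curve tangents.

    velocities: (n, 2) per-cell velocities.
    tangents_by_source: {source: (cell_indices, [tangents_per_outgoing_edge])},
        where each tangent array has shape (len(cell_indices), 2).
    """
    velocities = np.asarray(velocities, dtype=float)
    best_scores = []
    for cells, tangent_list in tangents_by_source.values():
        v = velocities[cells]
        # One row of cosine scores per outgoing edge A -> B.
        scores = np.vstack([cosine_rows(v, t) for t in tangent_list])
        # b_c^(A): keep the best-matching branch for each source cell.
        best_scores.append(np.nanmax(scores, axis=0))
    # Pooling all source cells before averaging gives the cell-weighted mean.
    return float(np.nanmean(np.concatenate(best_scores)))
```

Concatenating the per-source scores before a single nanmean is what makes the average cell-weighted rather than a mean of per-cluster means.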

Parameters:
  • adata – Annotated data matrix containing the embedding and velocity arrays.

  • cluster_edges (Sequence[Tuple[str, str]]) – Directed edges describing permitted transitions between clusters. Nodes are cast to strings to match adata.obs[cluster_key].

  • cluster_key (str) – Column in adata.obs holding cluster labels. Defaults to 'clusters'.

  • embedding_key (str) – Key identifying the 2D embedding. Defaults to 'umap', which is resolved to adata.obsm['X_umap'] when present. Any other key is resolved the same way (X_<key> or a direct entry).

  • velocity_key (str) – Key identifying the velocity representation (must align with the embedding). With the default 'velocity', the function looks for adata.obsm['velocity_umap'] when embedding_key='umap'.

  • use_bins (bool), n_bins (int), min_per_bin (int) – Controls for the skeletonisation step when fitting the principal curves.

  • n_samples (int) – Number of evaluation points used to project cells onto each curve.

  • store_in_adata (bool) – If True (default), write per-cell and aggregate scores back to adata.obs and adata.uns.

  • progress (bool) – Whether to display a progress bar over edges (requires tqdm).

Returns:

float – The overall BATC score (best-per-cell cosine mean).

Return type:

float

gravity.set_verbose(level)[source]

Set global verbosity level (higher means more logs).

Parameters:

level (int)

Return type:

None

gravity.get_verbose()[source]

Return the configured verbosity level.

Return type:

int

gravity.log_verbose(message, *, level=1)[source]

Print message when the verbosity is at least level.

Parameters:
  • message (Any)

  • level (int)

Return type:

None
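The gating rule above (print when the configured verbosity is at least level) can be illustrated with a tiny stand-alone sketch. It mirrors only the documented behaviour and is not GRAVITY's source; the names set_level, get_level and log_at_level are hypothetical, and the starting level of 0 is an assumption.

```python
_LEVEL = 0  # hypothetical global verbosity; assumed to start at 0


def set_level(level: int) -> None:
    """Set the global verbosity (higher means more output)."""
    global _LEVEL
    _LEVEL = int(level)


def get_level() -> int:
    """Return the current global verbosity."""
    return _LEVEL


def log_at_level(message, *, level: int = 1) -> None:
    # Emit only when the configured verbosity is at least `level`.
    if get_level() >= level:
        print(message)
```

Under this rule, a message logged with level=2 appears only after the verbosity has been raised to 2 or higher.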

Pipeline configuration and execution for the GRAVITY workflow.

This module exposes a high-level configuration dataclass and a single entry-point function that coordinates the complete GRAVITY pipeline:

  • preprocess long-format counts into a wide table used by subsequent stages,

  • train the cell-wise stage (stage 1) and export attention matrices if enabled,

  • estimate future positions from the stage-1 outputs,

  • train the gene-wise stage (stage 2),

  • optionally render velocity visualizations at the cell and gene level.

Multi-GPU and distributed execution are handled by PyTorch Lightning inside each training stage; this module coordinates preprocessing, stage transitions, future projection, and optional plotting.

class gravity.pipeline.PipelineConfig(raw_counts, workdir='gravity_outputs', prior_network='./prior_data/nichenet_mouse.zip', gene_subset=None, gene_order_path=None, batch_size=16, n_pos_neighbors=30, n_neg_neighbors=10, stage1_epochs=6, stage2_epochs=6, stage1_lr=1e-06, stage2_lr=0.0001, stage1_pretrained_checkpoint=None, stage2_pretrained_checkpoint=None, embedding_size=16, model_dimension=16, ffn_dimension=16, val_fraction_stage1=0.0, val_fraction_stage2=0.0, accelerator='auto', devices=None, num_workers=8, strategy=None, precision=None, gradient_clip_val=None, future_tau=0.5, log_every_n_steps=50, progress_bar=True, make_plot=False, plot_gene=None, plot_color='clusters', plot_genes=None, arrow_grid=(20, 20), arrow_scale=1.0, middle_csv_name='combine.csv', stage1_csv_name='stage1.csv', stage2_csv_name='stage2.csv', future_positions_name='future_positions.npy', stage1_checkpoint_name='stage1.ckpt', stage2_checkpoint_name='stage2.ckpt')[source]

Configuration for the full GRAVITY pipeline.

Parameters:
  • raw_counts (str) – Path to the input long-format CSV (must include at least cellID, gene_name, unsplice, splice, embedding1, embedding2).

  • workdir (str) – Output directory where intermediate and final artifacts are written.

  • prior_network (str | None) – Path to the prior TF–target network archive used by GRAVITY.

  • gene_subset (Sequence[str] | None) – Optional list of genes to restrict training and evaluation.

  • gene_order_path (str | None) – Optional newline-delimited gene list that fixes gene-index order. Use this when running pretrained/reference checkpoints; model weights and attention matrices are aligned by gene position, not only by gene name.

  • batch_size (int) – Mini-batch size used by both training stages.

  • stage1_epochs (int) – Number of training epochs for the cell-wise stage (stage 1).

  • stage2_epochs (int) – Number of training epochs for the gene-wise stage (stage 2).

  • stage1_lr (float) – Learning rate for the unsupervised cell-wise stage (stage 1). Both stages are somewhat learning-rate sensitive; for reference-style runs we recommend keeping stage1_lr below 1e-5.

  • stage2_lr (float) – Learning rate for the contrastive gene-wise stage (stage 2). For reference-style runs we recommend tuning stage2_lr between 1e-5 and 1e-3.

  • stage1_pretrained_checkpoint (str | None) – Optional stage-1 checkpoint used for inference/export instead of training the cell-wise stage. Use it for checkpoint-based reference reproduction.

  • stage2_pretrained_checkpoint (str | None) – Optional stage-2 checkpoint used for inference/export instead of training the gene-wise stage. Use it for checkpoint-based reference reproduction.

  • val_fraction_stage1 (float) – Fraction of data reserved for validation in the cell-wise stage.

  • val_fraction_stage2 (float) – Fraction of data reserved for validation in the gene-wise stage.

  • accelerator (str) – Forwarded to the PyTorch Lightning trainers; selects the device type (for example 'auto', 'cpu' or 'gpu').

  • devices (int | Sequence[int] | None) – Forwarded to the PyTorch Lightning trainers; number of devices or explicit device indices.

  • strategy (str | None) – Forwarded to the PyTorch Lightning trainers; distribution strategy (for example 'ddp').

  • precision (int | str | None) – Forwarded to the PyTorch Lightning trainers; numerical precision (for example 32 or 16).

  • gradient_clip_val (float | None) – Forwarded to the PyTorch Lightning trainers; gradient-clipping threshold (None disables clipping).

  • num_workers (int) – Number of worker processes used for data loading in the training stages.

  • future_tau (float) – Scaling factor governing the radius used in future-neighbor search.

  • log_every_n_steps (int) – How often, in training steps, metrics are logged.

  • progress_bar (bool) – Whether to display the training progress bar.

  • make_plot (bool) – Plotting options for optional velocity visualization.

  • plot_gene (str | None) – Plotting options for optional velocity visualization.

  • plot_color (str | None) – Plotting options for optional velocity visualization.

  • plot_genes (str | Sequence[str] | None) – Plotting options for optional velocity visualization.

  • arrow_grid (Tuple[int, int]) – Plotting options for optional velocity visualization.

  • arrow_scale (float) – Plotting options for optional velocity visualization.

  • middle_csv_name (str) – Filename of the intermediate (wide) CSV produced by preprocessing, written under workdir.

  • stage1_csv_name (str) – Filename of the stage-1 output CSV, written under workdir.

  • stage2_csv_name (str) – Filename of the stage-2 output CSV, written under workdir.

  • future_positions_name (str) – Filename of the future-positions .npy array, written under workdir.

  • stage1_checkpoint_name (str) – Filename of the stage-1 checkpoint, written under workdir.

  • stage2_checkpoint_name (str) – Filename of the stage-2 checkpoint, written under workdir.

  • n_pos_neighbors (int)

  • n_neg_neighbors (int)

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

raw_counts: str
workdir: str = 'gravity_outputs'
prior_network: str | None = './prior_data/nichenet_mouse.zip'
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 16
n_pos_neighbors: int = 30
n_neg_neighbors: int = 10
stage1_epochs: int = 6
stage2_epochs: int = 6
stage1_lr: float = 1e-06
stage2_lr: float = 0.0001
stage1_pretrained_checkpoint: str | None = None
stage2_pretrained_checkpoint: str | None = None
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16
val_fraction_stage1: float = 0.0
val_fraction_stage2: float = 0.0
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 8
strategy: str | None = None
precision: int | str | None = None
gradient_clip_val: float | None = None
future_tau: float = 0.5
log_every_n_steps: int = 50
progress_bar: bool = True
make_plot: bool = False
plot_gene: str | None = None
plot_color: str | None = 'clusters'
plot_genes: str | Sequence[str] | None = None
arrow_grid: Tuple[int, int] = (20, 20)
arrow_scale: float = 1.0
middle_csv_name: str = 'combine.csv'
stage1_csv_name: str = 'stage1.csv'
stage2_csv_name: str = 'stage2.csv'
future_positions_name: str = 'future_positions.npy'
stage1_checkpoint_name: str = 'stage1.ckpt'
stage2_checkpoint_name: str = 'stage2.ckpt'
gravity.pipeline.run_pipeline(config)[source]

Execute preprocessing, two training stages, future projection, and optional plotting.

Parameters:

config (PipelineConfig) – PipelineConfig instance that specifies inputs, training and plotting options, and output filenames.

Returns:

Dict[str, pathlib.Path] – Mapping of artifact names to absolute paths under workdir. Keys include middle_csv, stage1_csv, stage1_checkpoint, attention_dir, future_positions, stage2_csv, and stage2_checkpoint. If plotting is enabled, additional keys may be present for generated figures.

Return type:

Dict[str, Path]

Notes

Multi-GPU controls in config are forwarded to PyTorch Lightning inside the stage trainers. Preprocessing, future projection and plotting always run in the main process. With strategy='ddp', Lightning launches additional processes that may re-import and execute the entry script; to avoid duplicated orchestration, this function returns early on non-zero ranks.
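A minimal entry script might look like the following configuration sketch. The input path is hypothetical, and the multi-GPU values are examples of options forwarded to PyTorch Lightning; adjust them to your hardware. Because run_pipeline returns early on non-zero ranks, the script can call it unconditionally inside the usual main guard.

```python
from gravity.pipeline import PipelineConfig, run_pipeline

config = PipelineConfig(
    raw_counts="data/my_counts_long.csv",  # hypothetical input path
    workdir="gravity_outputs",
    accelerator="gpu",  # example values forwarded to PyTorch Lightning
    devices=2,
    strategy="ddp",
    make_plot=True,
)

if __name__ == "__main__":
    # Standard entry-point guard; run_pipeline itself returns early on non-zero ranks.
    artifacts = run_pipeline(config)
    # On rank 0 this maps artifact names (e.g. "stage2_csv") to paths under workdir.
    for name, path in (artifacts or {}).items():
        print(name, path)
```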

Data preprocessing utilities for GRAVITY training stages.

gravity.data.preprocessing.preprocess_counts(input_file, output_csv, *, gene_order=None)[source]

Prepare the cell-wise training table from a long-format single-cell CSV.

Parameters:
  • input_file (str) – Path to the raw long-format CSV with columns including cellID, gene_name, unsplice, splice, embedding1, embedding2.

  • output_csv (str) – Destination CSV containing one row per cell with serialized gene tuples.

  • gene_order (Sequence[str] | None) – Optional ordered gene list. When provided, matching gene columns are written first in this order so downstream checkpoints see the intended gene-index layout.

Returns:

pathlib.Path – The path to the generated intermediate CSV.

Return type:

Path
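The intermediate table is described only as "one row per cell with serialized gene tuples". Under that reading, the sketch below illustrates the pivot and the gene_order rule: each cell maps gene names to an (unsplice, splice) pair, and genes named in gene_order are placed first. The function name and in-memory layout are hypothetical, embedding columns are ignored, and the serialization actually written to output_csv may differ.

```python
from typing import Dict, List, Optional, Sequence, Tuple

CellTable = Dict[str, Dict[str, Tuple[float, float]]]


def long_to_wide(rows: List[Dict[str, str]],
                 gene_order: Optional[Sequence[str]] = None) -> Tuple[List[str], CellTable]:
    """Hypothetical sketch: pivot long-format rows into one record per cell."""
    cells: CellTable = {}
    seen: Dict[str, None] = {}  # insertion-ordered set of gene names
    for row in rows:
        gene = row["gene_name"]
        seen.setdefault(gene, None)
        cells.setdefault(row["cellID"], {})[gene] = (
            float(row["unsplice"]),
            float(row["splice"]),
        )
    # Genes from gene_order come first (when present), then the rest in first-seen order.
    ordered = [g for g in (gene_order or []) if g in seen]
    placed = set(ordered)
    ordered += [g for g in seen if g not in placed]
    return ordered, cells
```

Genes listed in gene_order but absent from the data are skipped here; how the real function treats such genes is not documented.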

gravity.data.preprocessing.load_cell_stage_dataset(middle_file, *, prior_path='./prior_data/nichenet_mouse.zip', gene_list=None, n_pos_neighbors=30, n_neg_neighbors=10)[source]

Instantiate the PyTorch dataset used for the cell-wise stage.

Parameters:
  • middle_file (str)

  • prior_path (str)

  • gene_list (Sequence[str] | None)

Return type:

CustomDataset

gravity.data.preprocessing.load_gene_stage_dataset(middle_file, *, prior_path='./prior_data/nichenet_mouse.zip', future_positions='./final_positions_with_index_yixian.npy', gene_list=None)[source]

Instantiate the PyTorch dataset used for the gene-wise refinement stage.

Parameters:
  • middle_file (str)

  • prior_path (str)

  • future_positions (str)

  • gene_list (Sequence[str] | None)

Return type:

CustomDatasetGeneWise

gravity.data.preprocessing.load_gene_order(path)[source]

Load a newline-delimited gene order file.

The order is part of the model contract for pretrained checkpoints because attention and solver tensors are indexed by gene position.

Parameters:

path (str)

Return type:

list[str]

gravity.data.preprocessing.resolve_gene_order(gene_subset=None, gene_order_path=None)[source]

Resolve the effective gene list, preserving checkpoint-compatible order.

Parameters:
  • gene_subset (Sequence[str] | None)

  • gene_order_path (str | None)

Return type:

list[str] | None

gravity.data.preprocessing.assert_gene_order_matches(path, expected, *, label='genes.txt')[source]

Raise if an existing gene-order file does not match expected order.

Parameters:
  • path (Path)

  • expected (Sequence[str])

  • label (str)

Return type:

None
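Because checkpoints index genes by position, reading the order file and verifying it against an expected list are the core checks. The sketch below is a stand-alone illustration with hypothetical names (read_gene_order, check_gene_order); skipping blank lines and raising ValueError are assumptions, not documented behaviour of load_gene_order or assert_gene_order_matches.

```python
from pathlib import Path
from typing import List, Sequence


def read_gene_order(path: str) -> List[str]:
    """Read one gene name per line, skipping blank lines."""
    lines = Path(path).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


def check_gene_order(path: str, expected: Sequence[str], label: str = "genes.txt") -> None:
    """Raise if the file's gene order differs from `expected`."""
    found = read_gene_order(path)
    # Compare position by position: same names in a different order still fail.
    if found != list(expected):
        raise ValueError(f"{label}: gene order in {path} does not match the expected order")
```

The comparison is deliberately order-sensitive, since a checkpoint trained with one gene layout cannot be reused with the same genes in a different order.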

gravity.data.preprocessing.export_intermediate_from_h5ad(input_h5ad, output_csv, *, retain_genes=None, min_shared_counts=20, n_top_genes=2000, n_pcs=30, n_neighbors=30, embed_key='X_umap', celltype_key='cell_type', overwrite=False, normalized=False)[source]

Preprocess AnnData for GRAVITY and export a cellDancer-style CSV.

This function performs only the preprocessing required to create GRAVITY’s user-facing long-format count table. It does not run RNA velocity inference, does not plot, and does not save a processed .h5ad.

Pipeline

  1. Read the input AnnData (.h5ad).

  2. Optionally force-keep user-specified genes through filtering/HVG selection.

  3. Normalize and filter genes with the local preprocessing settings.

  4. Compute first-order moments with scv.pp.moments (produces ‘Mu’/’Ms’).

  5. Export a cellDancer-style CSV via the local adata_to_df_with_embed function, including Mu/Ms, the chosen 2D embedding, and cell-type labels.

Parameters:
  • input_h5ad (str) – Path to an AnnData file that includes ‘spliced’ and ‘unspliced’ layers.

  • output_csv (str) – Destination CSV file path for the cellDancer-style long-format table.

  • retain_genes (Sequence[str] | None) – Genes that must be retained during filtering/HVG selection (if present).

  • min_shared_counts (int) – Minimum shared counts across cells for gene filtering.

  • n_top_genes (int) – Number of highly variable genes to retain (in addition to retain_genes).

  • n_pcs (int) – Number of principal components used by scv.pp.moments.

  • n_neighbors (int) – Neighborhood size used by scv.pp.moments.

  • embed_key (str) – Key in adata.obsm for a 2D embedding (e.g., “X_umap”).

  • celltype_key (str) – Column in adata.obs that holds cell-type labels (e.g., “cell_type” or “celltype”).

  • overwrite (bool) – If False and output_csv exists, skip work and return the existing path.

  • normalized (bool)

Returns:

pathlib.Path – The path to the generated CSV.

Raises:
  • KeyError – If embed_key is not found in adata.obsm.

  • RuntimeError – If the required layers (‘spliced’ and ‘unspliced’) are missing.

  • ImportError – If scanpy or scvelo is missing.

Return type:

Path

gravity.data.preprocessing.adata_to_df_with_embed(adata, us_para=('Mu', 'Ms'), cell_type_para='celltype', embed_para='X_umap', save_path='cell_type_u_s_sample_df.csv', gene_list=None)[source]

Convert an AnnData object to a long CSV with per-gene/per-cell rows and 2D embedding.

The resulting CSV contains, for every (gene, cell) pair:

  • gene_name, unsplice, splice

  • cellID, clusters (cell-type), embedding1, embedding2

Notes

  • The two matrices are taken from adata.layers[us_para[0]] (unspliced) and adata.layers[us_para[1]] (spliced). By default, that is [‘Mu’, ‘Ms’].

  • The 2D embedding is read from adata.obsm[embed_para] (e.g., ‘X_umap’).

  • This function writes the CSV incrementally (one gene at a time) to keep memory usage manageable for large datasets, then appends cell metadata.

Parameters:
  • adata – An anndata.AnnData object containing layers and embedding.

  • us_para (Sequence[str]) – Names of the two layers for unspliced and spliced moments/counts, respectively; default (‘Mu’, ‘Ms’).

  • cell_type_para (str) – Column name in adata.obs that holds cell-type labels (default ‘celltype’).

  • embed_para (str) – Key in adata.obsm for 2D embedding (default ‘X_umap’).

  • save_path (str) – Destination CSV file path.

  • gene_list (Sequence[str] | None) – Specific genes to export. If None, use all genes (adata.var.index).

Returns:

pandas.DataFrame – The final DataFrame that was saved to save_path.
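The gene-at-a-time export can be sketched with the standard csv module on dense arrays. This is a simplified illustration with a hypothetical name (write_long_csv): it writes cell metadata inline on every row rather than appending it afterwards, and it assumes the two layers are dense (n_cells, n_genes) arrays.

```python
import csv
from typing import IO, Sequence

import numpy as np


def write_long_csv(out: IO[str], genes: Sequence[str], cells: Sequence[str],
                   unspliced: np.ndarray, spliced: np.ndarray,
                   clusters: Sequence[str], embedding: np.ndarray) -> int:
    """Hypothetical sketch: write one row per (gene, cell), one gene at a time.

    unspliced/spliced: dense (n_cells, n_genes) arrays, e.g. the 'Mu'/'Ms' layers.
    embedding: (n_cells, 2) coordinates, e.g. adata.obsm['X_umap'].
    Returns the number of data rows written.
    """
    writer = csv.writer(out)
    writer.writerow(["gene_name", "unsplice", "splice", "cellID",
                     "clusters", "embedding1", "embedding2"])
    n_rows = 0
    for j, gene in enumerate(genes):
        # Rows are streamed gene by gene instead of building one large table in memory.
        for i, cell in enumerate(cells):
            writer.writerow([gene, unspliced[i, j], spliced[i, j], cell,
                             clusters[i], embedding[i, 0], embedding[i, 1]])
            n_rows += 1
    return n_rows
```

Streaming rows this way keeps the output size independent of peak memory, which is the motivation the Notes give for the incremental write.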

High-level helpers to train the GRAVITY cell-wise stage.

The functions in this module prepare datasets, configure PyTorch Lightning trainers, and export artefacts (CSV predictions, TF scores, priors) expected by downstream steps. They do not change the model’s inner computations.

class gravity.train.cell_stage.CellStageConfig(raw_counts, middle_csv, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage1_csv='stage1.csv', checkpoint_name='stage1.ckpt', pretrained_checkpoint=None, attention_dir='attentions', gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, attention_topk=64, attention_output=True, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=1e-06, embedding_size=16, model_dimension=16, ffn_dimension=16, n_pos_neighbors=30, n_neg_neighbors=10)[source]

Configuration bundle for the cell-wise training stage.

See also the top-level gravity.pipeline.PipelineConfig for how device/distribution options propagate. gene_order_path fixes the checkpoint-compatible gene index order when using pretrained/reference weights.

Parameters:
  • raw_counts (str)

  • middle_csv (str)

  • prior_network (str | None)

  • output_dir (str)

  • stage1_csv (str)

  • checkpoint_name (str)

  • pretrained_checkpoint (str | None)

  • attention_dir (str)

  • gene_subset (Sequence[str] | None)

  • gene_order_path (str | None)

  • batch_size (int)

  • epochs (int)

  • accelerator (str)

  • devices (int | Sequence[int] | None)

  • num_workers (int)

  • val_fraction (float)

  • attention_topk (int)

  • attention_output (bool)

  • precision (int | str | None)

  • gradient_clip_val (float | None)

  • strategy (str | None)

  • seed (int)

  • log_every_n_steps (int)

  • progress_bar (bool)

  • learning_rate (float)

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

  • n_pos_neighbors (int)

  • n_neg_neighbors (int)

raw_counts: str
middle_csv: str
prior_network: str | None = './prior_data/nichenet_mouse.zip'
output_dir: str = 'gravity_outputs'
stage1_csv: str = 'stage1.csv'
checkpoint_name: str = 'stage1.ckpt'
pretrained_checkpoint: str | None = None
attention_dir: str = 'attentions'
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 32
epochs: int = 6
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 0
val_fraction: float = 0.0
attention_topk: int = 64
attention_output: bool = True
precision: int | str | None = None
gradient_clip_val: float | None = None
strategy: str | None = None
seed: int = 42
log_every_n_steps: int = 50
progress_bar: bool = True
learning_rate: float = 1e-06
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16
n_pos_neighbors: int = 30
n_neg_neighbors: int = 10
gravity.train.cell_stage.train_cell_stage(config)[source]

Run the GRAVITY cell-wise stage and export key artefacts.

Returns:

Dict[str, Path] – Paths to stage1_csv, checkpoint, attention outputs (optional), and the generated gene files used by later steps.

Parameters:

config (CellStageConfig)

Return type:

Dict[str, Path]

High-level helpers to refine GRAVITY models in the gene-wise stage.

This module loads the stage-1 checkpoint, freezes most parameters, and enables fine-tuning of a restricted set of solver layers. Artefacts are exported in the same format used by the cell-wise stage to ease downstream consumption.

class gravity.train.gene_stage.GeneStageConfig(raw_counts, middle_csv, stage1_checkpoint, future_positions, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage2_csv='stage2.csv', checkpoint_name='stage2.ckpt', pretrained_checkpoint=None, gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=0.0001, embedding_size=16, model_dimension=16, ffn_dimension=16)[source]

Configuration bundle for the gene-wise fine-tuning stage.

Device and distribution arguments mirror those in gravity.pipeline.PipelineConfig. gene_order_path should point to the same gene list used by the stage-1 checkpoint when reproducing pretrained/reference runs.

Parameters:
  • raw_counts (str)

  • middle_csv (str)

  • stage1_checkpoint (str)

  • future_positions (str)

  • prior_network (str | None)

  • output_dir (str)

  • stage2_csv (str)

  • checkpoint_name (str)

  • pretrained_checkpoint (str | None)

  • gene_subset (Sequence[str] | None)

  • gene_order_path (str | None)

  • batch_size (int)

  • epochs (int)

  • accelerator (str)

  • devices (int | Sequence[int] | None)

  • num_workers (int)

  • val_fraction (float)

  • precision (int | str | None)

  • gradient_clip_val (float | None)

  • strategy (str | None)

  • seed (int)

  • log_every_n_steps (int)

  • progress_bar (bool)

  • learning_rate (float)

  • embedding_size (int)

  • model_dimension (int)

  • ffn_dimension (int)

raw_counts: str
middle_csv: str
stage1_checkpoint: str
future_positions: str
prior_network: str | None = './prior_data/nichenet_mouse.zip'
output_dir: str = 'gravity_outputs'
stage2_csv: str = 'stage2.csv'
checkpoint_name: str = 'stage2.ckpt'
pretrained_checkpoint: str | None = None
gene_subset: Sequence[str] | None = None
gene_order_path: str | None = None
batch_size: int = 32
epochs: int = 6
accelerator: str = 'auto'
devices: int | Sequence[int] | None = None
num_workers: int = 0
val_fraction: float = 0.0
precision: int | str | None = None
gradient_clip_val: float | None = None
strategy: str | None = None
seed: int = 42
log_every_n_steps: int = 50
progress_bar: bool = True
learning_rate: float = 0.0001
embedding_size: int = 16
model_dimension: int = 16
ffn_dimension: int = 16
gravity.train.gene_stage.train_gene_stage(config)[source]

Run the GRAVITY gene-wise refinement stage.

Returns:

Dict[str, Path] – Paths to stage2_csv, checkpoint and the gene/prior files.

Parameters:

config (GeneStageConfig)

Return type:

Dict[str, Path]

Utility to project GRAVITY velocities towards future positions.

gravity.tools.future.estimate_future_positions(stage1_csv, output_path, *, tau=0.5, show_plot=False, plot_path=None, projection_neighbor_choice='embedding', projection_neighbor_size=200, expression_scale='power10')[source]

Estimate future embeddings using the learned cell-wise velocities.

Parameters:
  • stage1_csv (str) – Output CSV from the cell-wise stage (stage1).

  • output_path (str) – Destination .npy file to store the nearest-neighbour anchors.

  • tau (float) – Scaling factor that shrinks the velocity vectors when forming the search radius.

  • show_plot (bool) – Whether to display a matplotlib window with the quiver plot.

  • plot_path (str | None) – Optional path to save the figure instead of (or in addition to) showing it.

  • projection_neighbor_choice (str)

  • projection_neighbor_size (int)

  • expression_scale (str | None)

Returns:

final_positions, cell_df – Tuple of the n × 3 array with neighbour coordinates + indices, and the augmented dataframe returned by compute_cell_velocity_().

Return type:

Tuple[ndarray, DataFrame]