gravity Package
User-friendly entry points for the GRAVITY library.
- class gravity.PipelineConfig(raw_counts, workdir='gravity_outputs', prior_network='./prior_data/nichenet_mouse.zip', gene_subset=None, gene_order_path=None, batch_size=16, n_pos_neighbors=30, n_neg_neighbors=10, stage1_epochs=6, stage2_epochs=6, stage1_lr=1e-06, stage2_lr=0.0001, stage1_pretrained_checkpoint=None, stage2_pretrained_checkpoint=None, embedding_size=16, model_dimension=16, ffn_dimension=16, val_fraction_stage1=0.0, val_fraction_stage2=0.0, accelerator='auto', devices=None, num_workers=8, strategy=None, precision=None, gradient_clip_val=None, future_tau=0.5, log_every_n_steps=50, progress_bar=True, make_plot=False, plot_gene=None, plot_color='clusters', plot_genes=None, arrow_grid=(20, 20), arrow_scale=1.0, middle_csv_name='combine.csv', stage1_csv_name='stage1.csv', stage2_csv_name='stage2.csv', future_positions_name='future_positions.npy', stage1_checkpoint_name='stage1.ckpt', stage2_checkpoint_name='stage2.ckpt')
Bases: object
Configuration for the full GRAVITY pipeline.
- Parameters:
raw_counts (str) – Path to the input long-format CSV (must include at least cellID, gene_name, unsplice, splice, embedding1, embedding2).
workdir (str) – Output directory where intermediate and final artifacts are written.
prior_network (str | None) – Path to the prior TF–target network archive used by GRAVITY.
gene_subset (Sequence[str] | None) – Optional list of genes to restrict training and evaluation.
gene_order_path (str | None) – Optional newline-delimited gene list that fixes gene-index order. Use this when running pretrained/reference checkpoints; model weights and attention matrices are aligned by gene position, not only by gene name.
batch_size (int) – Mini-batch size used by both training stages.
stage1_epochs (int) – Number of epochs per stage.
stage2_epochs (int) – Number of epochs per stage.
stage1_lr (float) – Learning rates per stage. The unsupervised cell-wise stage and contrastive gene-wise stage are somewhat learning-rate sensitive; for reference-style runs we recommend keeping stage1_lr below 1e-5 and tuning stage2_lr between 1e-3 and 1e-5.
stage2_lr (float) – Learning rates per stage. The unsupervised cell-wise stage and contrastive gene-wise stage are somewhat learning-rate sensitive; for reference-style runs we recommend keeping stage1_lr below 1e-5 and tuning stage2_lr between 1e-3 and 1e-5.
stage1_pretrained_checkpoint (str | None) – Optional checkpoints used for inference/export instead of training the corresponding stage. Use these for checkpoint-based reference reproduction.
stage2_pretrained_checkpoint (str | None) – Optional checkpoints used for inference/export instead of training the corresponding stage. Use these for checkpoint-based reference reproduction.
val_fraction_stage1 (float) – Fraction of data reserved for validation in each stage.
val_fraction_stage2 (float) – Fraction of data reserved for validation in each stage.
accelerator (str) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
devices (int | Sequence[int] | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
strategy (str | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
precision (int | str | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
gradient_clip_val (float | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
num_workers (int) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
future_tau (float) – Scaling factor governing the radius used in future-neighbor search.
log_every_n_steps (int) – Logging and progress display controls.
progress_bar (bool) – Logging and progress display controls.
make_plot (bool) – Plotting options for optional velocity visualization.
plot_gene (str | None) – Plotting options for optional velocity visualization.
plot_color (str | None) – Plotting options for optional velocity visualization.
plot_genes (str | Sequence[str] | None) – Plotting options for optional velocity visualization.
arrow_grid (Tuple[int, int]) – Plotting options for optional velocity visualization.
arrow_scale (float) – Plotting options for optional velocity visualization.
middle_csv_name (str) – Filenames for artifacts written under workdir.
stage*_csv_name – Filenames for artifacts written under workdir.
future_positions_name (str) – Filenames for artifacts written under workdir.
stage*_checkpoint_name – Filenames for artifacts written under workdir.
n_pos_neighbors (int)
n_neg_neighbors (int)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
stage1_csv_name (str)
stage2_csv_name (str)
stage1_checkpoint_name (str)
stage2_checkpoint_name (str)
- raw_counts: str
- workdir: str = 'gravity_outputs'
- prior_network: str | None = './prior_data/nichenet_mouse.zip'
- gene_subset: Sequence[str] | None = None
- gene_order_path: str | None = None
- batch_size: int = 16
- n_pos_neighbors: int = 30
- n_neg_neighbors: int = 10
- stage1_epochs: int = 6
- stage2_epochs: int = 6
- stage1_lr: float = 1e-06
- stage2_lr: float = 0.0001
- stage1_pretrained_checkpoint: str | None = None
- stage2_pretrained_checkpoint: str | None = None
- embedding_size: int = 16
- model_dimension: int = 16
- ffn_dimension: int = 16
- val_fraction_stage1: float = 0.0
- val_fraction_stage2: float = 0.0
- accelerator: str = 'auto'
- devices: int | Sequence[int] | None = None
- num_workers: int = 8
- strategy: str | None = None
- precision: int | str | None = None
- gradient_clip_val: float | None = None
- future_tau: float = 0.5
- log_every_n_steps: int = 50
- progress_bar: bool = True
- make_plot: bool = False
- plot_gene: str | None = None
- plot_color: str | None = 'clusters'
- plot_genes: str | Sequence[str] | None = None
- arrow_grid: Tuple[int, int] = (20, 20)
- arrow_scale: float = 1.0
- middle_csv_name: str = 'combine.csv'
- stage1_csv_name: str = 'stage1.csv'
- stage2_csv_name: str = 'stage2.csv'
- future_positions_name: str = 'future_positions.npy'
- stage1_checkpoint_name: str = 'stage1.ckpt'
- stage2_checkpoint_name: str = 'stage2.ckpt'
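As a minimal sketch of how these fields might be assembled, the snippet below collects keyword arguments for PipelineConfig and checks the two learning rates against the ranges recommended for reference-style runs. The helper names check_learning_rates and build_pipeline_kwargs and the input path counts_long.csv are hypothetical; only the field names and defaults come from the signature above, and treating the stage-2 range as inclusive is an assumption.

```python
from typing import Dict, List, Optional


def check_learning_rates(stage1_lr: float, stage2_lr: float) -> List[str]:
    """Return warnings when the rates fall outside the documented recommendation."""
    warnings = []
    if not stage1_lr < 1e-5:
        warnings.append("stage1_lr should stay below 1e-5")
    # Assumption: the recommended stage-2 range is treated as inclusive.
    if not 1e-5 <= stage2_lr <= 1e-3:
        warnings.append("stage2_lr should lie between 1e-5 and 1e-3")
    return warnings


def build_pipeline_kwargs(
    raw_counts: str,
    workdir: str = "gravity_outputs",
    gene_order_path: Optional[str] = None,
    stage1_lr: float = 1e-6,
    stage2_lr: float = 1e-4,
) -> Dict[str, object]:
    """Collect keyword arguments using field names from the PipelineConfig signature."""
    for message in check_learning_rates(stage1_lr, stage2_lr):
        print("warning:", message)
    return {
        "raw_counts": raw_counts,
        "workdir": workdir,
        "gene_order_path": gene_order_path,
        "stage1_lr": stage1_lr,
        "stage2_lr": stage2_lr,
    }


# Hypothetical input path, used only for illustration.
kwargs = build_pipeline_kwargs("counts_long.csv")

# With GRAVITY installed, the dictionary could then be passed on:
#     from gravity import PipelineConfig
#     config = PipelineConfig(**kwargs)
```

Keeping the checks separate from the config object means they can run before any heavy imports or training start.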
- gravity.run_pipeline(config)
Execute preprocessing, two training stages, future projection, and optional plotting.
- Parameters:
config (PipelineConfig) – PipelineConfig instance that specifies inputs, training and plotting options, and output filenames.
- Returns:
Dict[str, pathlib.Path] – Mapping of artifact names to absolute paths under workdir. Keys include middle_csv, stage1_csv, stage1_checkpoint, attention_dir, future_positions, stage2_csv, and stage2_checkpoint. If plotting is enabled, additional keys may be present for generated figures.
- Return type:
Dict[str, Path]
Notes
Multi-GPU controls in config are forwarded to PyTorch Lightning inside the stage trainers. Preprocessing, future projection and plotting always run in the main process. When using DDP with strategy='ddp' (spawn), child processes may re-import and execute the entry script; to avoid duplicated orchestration, this function performs an early return on non-zero ranks.
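Because child processes may re-import the entry script, it can help to keep the call behind a main guard. The following is a sketch under stated assumptions: the script name run_gravity.py and the helper describe_artifacts are hypothetical, the device settings are illustrative, and the value returned on non-zero ranks is not documented here, so the code falls back to an empty mapping.

```python
import sys
from pathlib import Path
from typing import Dict, List


def describe_artifacts(artifacts: Dict[str, Path]) -> List[str]:
    """Format an artifact mapping as sorted 'key: path' lines."""
    return [f"{name}: {path}" for name, path in sorted(artifacts.items())]


def main(argv: List[str]) -> None:
    if not argv:
        print("usage: python run_gravity.py RAW_COUNTS_CSV")
        return
    # Imported lazily so the helper above works without GRAVITY installed.
    from gravity import PipelineConfig, run_pipeline

    config = PipelineConfig(
        raw_counts=argv[0],
        accelerator="gpu",  # illustrative device settings
        devices=2,
        strategy="ddp",
    )
    # Assumption: non-zero ranks may return nothing after the early return.
    artifacts = run_pipeline(config) or {}
    for line in describe_artifacts(artifacts):
        print(line)


# The guard keeps a re-import of this script from starting a second run.
if __name__ == "__main__":
    main(sys.argv[1:])
```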
- gravity.preprocess_counts(input_file, output_csv, *, gene_order=None)
Prepare the cell-wise training table from a long single-cell CSV.
- Parameters:
input_file (str) – Path to the raw long-format CSV with columns including cellID, gene_name, unsplice, splice, embedding1, embedding2.
output_csv (str) – Destination CSV containing one row per cell with serialized gene tuples.
gene_order (Sequence[str] | None) – Optional ordered gene list. When provided, matching gene columns are written first in this order so downstream checkpoints see the intended gene-index layout.
- Returns:
pathlib.Path – The path to the generated intermediate CSV.
- Return type:
Path
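Before calling preprocess_counts, it may be worth confirming that the input CSV carries the columns listed above. The check below is a small sketch using only the standard library; the helper names and the sample row (including the extra clusters column and the gene name Gene_a) are made up for illustration.

```python
import csv
import io
from typing import Iterable, List

# Column names taken from the parameter description of input_file.
REQUIRED_COLUMNS = (
    "cellID", "gene_name", "unsplice", "splice", "embedding1", "embedding2",
)


def missing_columns(header: Iterable[str]) -> List[str]:
    """Return the required columns that are absent from a CSV header."""
    present = set(header)
    return [name for name in REQUIRED_COLUMNS if name not in present]


def read_header(handle) -> List[str]:
    """Read the first CSV row from an open text handle (empty list if none)."""
    return next(csv.reader(handle), [])


# Illustrative long-format table: one row per (cell, gene) pair.
sample = io.StringIO(
    "cellID,gene_name,unsplice,splice,embedding1,embedding2,clusters\n"
    "cell_0,Gene_a,0.1,0.4,1.5,-2.0,typeA\n"
)
print(missing_columns(read_header(sample)))  # prints []
```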
- gravity.load_gene_order(path)
Load a newline-delimited gene order file.
The order is part of the model contract for pretrained checkpoints because attention and solver tensors are indexed by gene position.
- Parameters:
path (str)
- Return type:
list[str]
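To illustrate the file format, here is a stand-in reader for a newline-delimited gene list. It is not the library's implementation: the function name read_gene_list, the handling of blank lines, and the gene names are assumptions made for the example.

```python
import tempfile
from pathlib import Path
from typing import List


def read_gene_list(path: str) -> List[str]:
    """Stand-in reader: one gene per line, blank lines skipped, order kept."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    order_file = Path(tmp) / "gene_order.txt"
    # The file order, not alphabetical order, defines the gene indices.
    order_file.write_text("Gene_b\nGene_a\n\nGene_c\n", encoding="utf-8")
    genes = read_gene_list(str(order_file))

print(genes)  # prints ['Gene_b', 'Gene_a', 'Gene_c']
```

The position of each name in the returned list is what a position-indexed checkpoint would rely on, which is why the reader must not sort or deduplicate silently.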
- gravity.resolve_gene_order(gene_subset=None, gene_order_path=None)
Resolve the effective gene list, preserving checkpoint-compatible order.
- Parameters:
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
- Return type:
list[str] | None
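The precedence rules are not spelled out here, so the following is only one plausible policy, written as a sketch: when an ordered list is available it dictates the order and the subset merely filters it. The function resolve_order_sketch is hypothetical and takes an already loaded list instead of a file path to stay self-contained.

```python
from typing import List, Optional, Sequence


def resolve_order_sketch(
    gene_subset: Optional[Sequence[str]] = None,
    ordered_genes: Optional[Sequence[str]] = None,
) -> Optional[List[str]]:
    """One possible policy (assumption): the ordered list wins on ordering."""
    if ordered_genes is None:
        # No fixed order: fall back to the subset as given, or to no restriction.
        return list(gene_subset) if gene_subset is not None else None
    if gene_subset is None:
        return list(ordered_genes)
    wanted = set(gene_subset)
    # Keep the checkpoint order and drop genes outside the subset.
    return [gene for gene in ordered_genes if gene in wanted]


print(resolve_order_sketch(["Gene_a", "Gene_c"], ["Gene_c", "Gene_b", "Gene_a"]))
# prints ['Gene_c', 'Gene_a']
```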
- gravity.export_intermediate_from_h5ad(input_h5ad, output_csv, *, retain_genes=None, min_shared_counts=20, n_top_genes=2000, n_pcs=30, n_neighbors=30, embed_key='X_umap', celltype_key='cell_type', overwrite=False, normalized=False)
Preprocess AnnData for GRAVITY and export a cellDancer-style CSV.
This function performs only the preprocessing required to create GRAVITY’s user-facing long-format count table. It does not run RNA velocity inference, does not plot, and does not save a processed .h5ad.
Pipeline
Read the input AnnData (.h5ad).
Optionally force-keep user-specified genes through filtering/HVG selection.
Normalize and filter genes with the local preprocessing settings.
Compute first-order moments with scv.pp.moments (produces 'Mu'/'Ms').
Export a cellDancer-style CSV via the local adata_to_df_with_embed function, including Mu/Ms, the chosen 2D embedding, and cell-type labels.
- Parameters:
input_h5ad (str) – Path to an AnnData file that includes 'spliced' and 'unspliced' layers.
output_csv (str) – Destination CSV file path for the cellDancer-style long-format table.
retain_genes (Sequence[str] | None) – Genes that must be retained during filtering/HVG selection (if present).
min_shared_counts (int) – Minimum shared counts across cells for gene filtering.
n_top_genes (int) – Number of highly variable genes to retain (in addition to retain_genes).
n_pcs (int) – Number of principal components used by scv.pp.moments.
n_neighbors (int) – Neighborhood size used by scv.pp.moments.
embed_key (str) – Key in adata.obsm for a 2D embedding (e.g., "X_umap").
celltype_key (str) – Column in adata.obs that holds cell-type labels (e.g., "cell_type" or "celltype").
overwrite (bool) – If False and output_csv exists, skip work and return the existing path.
normalized (bool)
- Returns:
pathlib.Path – The path to the generated CSV.
- Raises:
KeyError – If embed_key is not found in adata.obsm.
RuntimeError – If required layers ('spliced' and 'unspliced') are missing.
ImportError – If scanpy or scvelo are missing.
- Return type:
Path
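The overwrite behaviour described above (skip the work and return the existing path when the file is already there) can be sketched as a small guard. The function export_if_missing and the fake_export callback are hypothetical stand-ins; the real function performs the AnnData preprocessing where the callback is invoked.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def export_if_missing(
    output_csv: str, build: Callable[[Path], None], overwrite: bool = False
) -> Path:
    """Run build(target) unless the target exists and overwrite is False."""
    target = Path(output_csv)
    if target.exists() and not overwrite:
        return target  # reuse the existing file
    target.parent.mkdir(parents=True, exist_ok=True)
    build(target)
    return target


calls: List[Path] = []


def fake_export(path: Path) -> None:
    """Placeholder for the expensive export step."""
    calls.append(path)
    path.write_text("cellID,gene_name\n", encoding="utf-8")


with tempfile.TemporaryDirectory() as tmp:
    out = str(Path(tmp) / "counts_long.csv")
    export_if_missing(out, fake_export)                  # builds the file
    export_if_missing(out, fake_export)                  # skipped, file exists
    export_if_missing(out, fake_export, overwrite=True)  # rebuilt on request

print(len(calls))  # prints 2
```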
- class gravity.CellStageConfig(raw_counts, middle_csv, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage1_csv='stage1.csv', checkpoint_name='stage1.ckpt', pretrained_checkpoint=None, attention_dir='attentions', gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, attention_topk=64, attention_output=True, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=1e-06, embedding_size=16, model_dimension=16, ffn_dimension=16, n_pos_neighbors=30, n_neg_neighbors=10)
Bases: object
Configuration bundle for the cell-wise training stage.
See also the top-level gravity.pipeline.PipelineConfig for how device/distribution options propagate. gene_order_path fixes the checkpoint-compatible gene index order when using pretrained/reference weights.
- Parameters:
raw_counts (str)
middle_csv (str)
prior_network (str | None)
output_dir (str)
stage1_csv (str)
checkpoint_name (str)
pretrained_checkpoint (str | None)
attention_dir (str)
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
batch_size (int)
epochs (int)
accelerator (str)
devices (int | Sequence[int] | None)
num_workers (int)
val_fraction (float)
attention_topk (int)
attention_output (bool)
precision (int | str | None)
gradient_clip_val (float | None)
strategy (str | None)
seed (int)
log_every_n_steps (int)
progress_bar (bool)
learning_rate (float)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
n_pos_neighbors (int)
n_neg_neighbors (int)
- raw_counts: str
- middle_csv: str
- prior_network: str | None = './prior_data/nichenet_mouse.zip'
- output_dir: str = 'gravity_outputs'
- stage1_csv: str = 'stage1.csv'
- checkpoint_name: str = 'stage1.ckpt'
- pretrained_checkpoint: str | None = None
- attention_dir: str = 'attentions'
- gene_subset: Sequence[str] | None = None
- gene_order_path: str | None = None
- batch_size: int = 32
- epochs: int = 6
- accelerator: str = 'auto'
- devices: int | Sequence[int] | None = None
- num_workers: int = 0
- val_fraction: float = 0.0
- attention_topk: int = 64
- attention_output: bool = True
- precision: int | str | None = None
- gradient_clip_val: float | None = None
- strategy: str | None = None
- seed: int = 42
- log_every_n_steps: int = 50
- progress_bar: bool = True
- learning_rate: float = 1e-06
- embedding_size: int = 16
- model_dimension: int = 16
- ffn_dimension: int = 16
- n_pos_neighbors: int = 30
- n_neg_neighbors: int = 10
- gravity.train_cell_stage(config)
Run the GRAVITY cell-wise stage and export key artefacts.
- Returns:
Dict[str, Path] – Paths to stage1_csv, checkpoint, attention outputs (optional), and the generated gene files used by later steps.
- Parameters:
config (CellStageConfig)
- Return type:
Dict[str, Path]
- class gravity.GeneStageConfig(raw_counts, middle_csv, stage1_checkpoint, future_positions, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage2_csv='stage2.csv', checkpoint_name='stage2.ckpt', pretrained_checkpoint=None, gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=0.0001, embedding_size=16, model_dimension=16, ffn_dimension=16)
Bases: object
Configuration bundle for the gene-wise fine-tuning stage.
Device and distribution arguments mirror those in gravity.pipeline.PipelineConfig. gene_order_path should point to the same gene list used by the stage-1 checkpoint when reproducing pretrained/reference runs.
- Parameters:
raw_counts (str)
middle_csv (str)
stage1_checkpoint (str)
future_positions (str)
prior_network (str | None)
output_dir (str)
stage2_csv (str)
checkpoint_name (str)
pretrained_checkpoint (str | None)
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
batch_size (int)
epochs (int)
accelerator (str)
devices (int | Sequence[int] | None)
num_workers (int)
val_fraction (float)
precision (int | str | None)
gradient_clip_val (float | None)
strategy (str | None)
seed (int)
log_every_n_steps (int)
progress_bar (bool)
learning_rate (float)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
- raw_counts: str
- middle_csv: str
- stage1_checkpoint: str
- future_positions: str
- prior_network: str | None = './prior_data/nichenet_mouse.zip'
- output_dir: str = 'gravity_outputs'
- stage2_csv: str = 'stage2.csv'
- checkpoint_name: str = 'stage2.ckpt'
- pretrained_checkpoint: str | None = None
- gene_subset: Sequence[str] | None = None
- gene_order_path: str | None = None
- batch_size: int = 32
- epochs: int = 6
- accelerator: str = 'auto'
- devices: int | Sequence[int] | None = None
- num_workers: int = 0
- val_fraction: float = 0.0
- precision: int | str | None = None
- gradient_clip_val: float | None = None
- strategy: str | None = None
- seed: int = 42
- log_every_n_steps: int = 50
- progress_bar: bool = True
- learning_rate: float = 0.0001
- embedding_size: int = 16
- model_dimension: int = 16
- ffn_dimension: int = 16
- gravity.train_gene_stage(config)
Run the GRAVITY gene-wise refinement stage.
- Returns:
Dict[str, Path] – Paths to stage2_csv, checkpoint and the gene/prior files.
- Parameters:
config (GeneStageConfig)
- Return type:
Dict[str, Path]
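For readers who want to run the stages by hand instead of through run_pipeline, the sketch below chains the documented entry points. It rests on several assumptions: the dictionary keys stage1_csv and checkpoint are taken from the train_cell_stage return description, the future projection is placed between the two stages because GeneStageConfig requires future_positions, and the helpers plan_stage_paths and run_stages are hypothetical.

```python
from pathlib import Path
from typing import Dict


def plan_stage_paths(workdir: str) -> Dict[str, Path]:
    """Intermediate file locations, using the default names from PipelineConfig."""
    root = Path(workdir)
    return {
        "middle_csv": root / "combine.csv",
        "future_positions": root / "future_positions.npy",
    }


def run_stages(raw_counts: str, workdir: str = "gravity_outputs") -> Dict[str, Path]:
    """Run preprocessing, stage 1, future projection and stage 2 in order."""
    # Imported lazily so plan_stage_paths works without GRAVITY installed.
    from gravity import (
        CellStageConfig,
        GeneStageConfig,
        estimate_future_positions,
        preprocess_counts,
        train_cell_stage,
        train_gene_stage,
    )

    paths = plan_stage_paths(workdir)
    Path(workdir).mkdir(parents=True, exist_ok=True)
    middle_csv = str(paths["middle_csv"])
    future_positions = str(paths["future_positions"])

    preprocess_counts(raw_counts, middle_csv)
    stage1 = train_cell_stage(
        CellStageConfig(raw_counts=raw_counts, middle_csv=middle_csv, output_dir=workdir)
    )
    # Assumption: the returned mapping exposes 'stage1_csv' and 'checkpoint'.
    estimate_future_positions(str(stage1["stage1_csv"]), future_positions, tau=0.5)
    return train_gene_stage(
        GeneStageConfig(
            raw_counts=raw_counts,
            middle_csv=middle_csv,
            stage1_checkpoint=str(stage1["checkpoint"]),
            future_positions=future_positions,
            output_dir=workdir,
        )
    )
```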
- class gravity.FullModelCellWise(input_dim, TF_list, TG_list, TFTG_map, TFTG_map_reverse, embedding_size=16, model_dimension=16, ffn_dimension=16, output_dim_trans=10, output_dim_dense=20, gene_list=None, embedding_map=None, id_map=None, gene_select=None, nbrs=None, origin_data=None, negs=None, output_csv=None, attention_output=True, output_network_path='attentions', cell_to_type=None, gene_list_path='genes.txt', gene_mapper_path='genemap.json', csr_topk=64, learning_rate=1e-06)
Bases: LightningModule
PyTorch Lightning module encapsulating the cell-wise GRAVITY stage.
- Parameters:
input_dim (int)
TF_list (Sequence[str])
TG_list (Sequence[str])
TFTG_map (Dict[str, Sequence[str]])
TFTG_map_reverse (Dict[str, Sequence[str]])
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
output_dim_trans (int)
output_dim_dense (int)
gene_list (Optional[Sequence[str]])
embedding_map (Optional[Dict[int, str]])
id_map (Optional[Dict[int, str]])
gene_select (Optional[Sequence[str]])
nbrs (Optional[np.ndarray])
origin_data (Optional[pd.DataFrame])
negs (Optional[np.ndarray])
output_csv (Optional[str])
attention_output (bool)
output_network_path (str)
cell_to_type (Optional[Dict[str, str]])
gene_list_path (str)
gene_mapper_path (str)
csr_topk (int)
learning_rate (float)
- forward(x, mask=None)
Same as torch.nn.Module.forward().
- Args:
*args: Whatever you decide to pass into the forward method.
**kwargs: Keyword arguments are also possible.
- Return:
Your model’s output
- Parameters:
x (Tensor)
mask (Tensor | None)
- training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Args:
batch (Tensor | (Tensor, …) | [Tensor, …]): The output of your DataLoader. A tensor, tuple or list.
batch_idx (int): Integer displaying index of this batch
- Return:
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()

        # do training_step with encoder
        ...
        opt1.step()
        # do training_step with decoder
        ...
        opt2.step()
- Note:
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- test_step(batch, batch_idx)
Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
- Args:
batch: The output of your DataLoader.
batch_idx: The index of this batch.
dataloader_idx: The index of the dataloader that produced this batch (only if multiple test dataloaders are used).
- Return:
Any of:
Any object or value
None - Testing will skip to the next batch
    # if you have one test dataloader:
    def test_step(self, batch, batch_idx):
        ...

    # if you have multiple test dataloaders:
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        ...
Examples:
    # CASE 1: A single test dataset
    def test_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
    # CASE 2: multiple test dataloaders
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx tells you which dataset this is.
        ...
- Note:
If you don't need to test you don't need to implement this method.
- Note:
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- on_test_epoch_end()
Called in the test loop at the very end of the epoch.
- Return type:
None
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Return:
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging it using self.log('metric_to_track', metric_val) in your LightningModule.
- Note:
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
- class gravity.FullModelGeneWise(input_dim, TF_list, TG_list, TFTG_map, TFTG_map_reverse, embedding_size=16, model_dimension=16, ffn_dimension=16, output_dim_trans=10, output_dim_dense=20, gene_list=None, embedding_map=None, id_map=None, gene_select=None, nbrs=None, origin_data=None, negs=None, output_csv=None, learning_rate=0.0001)
Bases: LightningModule
- Parameters:
input_dim (int)
TF_list (Sequence[str])
TG_list (Sequence[str])
TFTG_map (Dict[str, Sequence[str]])
TFTG_map_reverse (Dict[str, Sequence[str]])
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
output_dim_trans (int)
output_dim_dense (int)
gene_list (Optional[Sequence[str]])
embedding_map (Optional[Dict[int, str]])
id_map (Optional[Dict[int, str]])
gene_select (Optional[Sequence[str]])
nbrs (Optional[np.ndarray])
origin_data (Optional[pd.DataFrame])
negs (Optional[np.ndarray])
output_csv (Optional[str])
learning_rate (float)
- training_step(batch, batch_idx)
- test_step(batch, batch_idx)
- on_test_epoch_end()
Called in the test loop at the very end of the epoch.
- Return type:
None
- configure_optimizers()
- gravity.estimate_future_positions(stage1_csv, output_path, *, tau=0.5, show_plot=False, plot_path=None, projection_neighbor_choice='embedding', projection_neighbor_size=200, expression_scale='power10')
Estimate future embeddings using the learned cell-wise velocities.
- Parameters:
stage1_csv (str) – Output CSV from the cell-wise stage (stage1).
output_path (str) – Destination .npy file to store the nearest-neighbour anchors.
tau (float) – Scaling factor that shrinks the velocity vectors when forming the search radius.
show_plot (bool) – Whether to display a matplotlib window with the quiver plot.
plot_path (str | None) – Optional path to save the figure instead of (or in addition to) showing it.
projection_neighbor_choice (str)
projection_neighbor_size (int)
expression_scale (str | None)
- Returns:
final_positions, cell_df – Tuple of the n × 3 array with neighbour coordinates + indices, and the augmented dataframe returned by compute_cell_velocity_().
- Return type:
Tuple[ndarray, DataFrame]
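The idea of a future anchor can be sketched in plain Python: shift each cell by tau times its velocity, then pick the closest other cell. This is an illustration only. The function name nearest_future_neighbors and the exhaustive nearest-neighbour search are assumptions; the library's actual radius-based search and its projection options are not reproduced here.

```python
import math
from typing import List, Sequence, Tuple


def nearest_future_neighbors(
    positions: Sequence[Tuple[float, float]],
    velocities: Sequence[Tuple[float, float]],
    tau: float = 0.5,
) -> List[Tuple[float, float, int]]:
    """Return one (x, y, index) anchor per cell, mimicking an n x 3 layout."""
    anchors = []
    for i, ((x, y), (vx, vy)) in enumerate(zip(positions, velocities)):
        # Target point: current position shifted by the shrunken velocity.
        tx, ty = x + tau * vx, y + tau * vy
        best_j, best_d = i, math.inf  # falls back to the cell itself if alone
        for j, (px, py) in enumerate(positions):
            if j == i:
                continue
            d = math.hypot(px - tx, py - ty)
            if d < best_d:
                best_j, best_d = j, d
        anchors.append((positions[best_j][0], positions[best_j][1], best_j))
    return anchors
```

A smaller tau keeps the target point closer to the current cell, so the selected anchor tends to be a nearer neighbour.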
- gravity.plot_velocity_gene(stage_csv, *, gene=None, x='splice', y='unsplice', color_by='clusters', palette=None, categories=None, cmap='viridis', point_size=200.0, alpha=0.5, arrow_grid=(20, 20), arrow_scale=1.0, arrow_color='k', projection_neighbor_choice='embedding', expression_scale='power10', projection_neighbor_size=200, min_mass=0.5, axis_off=True, output_path=None, show=False)[source]¶
Plot gene-level phase portraits with projected velocity arrows.
The arrows connect current to predicted expression for a sampled subset of cells.
- Parameters:
stage_csv (str)
gene (str | None)
x (str)
y (str)
color_by (str | None)
palette (Mapping[str, str] | None)
categories (Sequence[str] | None)
cmap (str)
point_size (float)
alpha (float)
arrow_grid (Tuple[int, int])
arrow_scale (float)
arrow_color (str)
projection_neighbor_choice (str)
expression_scale (str | None)
projection_neighbor_size (int)
min_mass (float)
axis_off (bool)
output_path (str | None)
show (bool)
- Return type:
Axes
- gravity.plot_velocity_cell(stage_csv, *, gene=None, x='splice', y='unsplice', color_by='clusters', palette=None, categories=None, cmap='viridis', point_size=200.0, alpha=0.5, arrow_grid=(20, 20), arrow_scale=1.0, arrow_color='k', projection_neighbor_choice='embedding', expression_scale='power10', projection_neighbor_size=200, min_mass=0.5, axis_off=True, output_path=None, show=False)[source]¶
Plot cell-level velocities in the embedding space.
The function computes velocity projections via gravity.velocity.compute_cell_velocity_() and overlays grid-curved arrows. Color handling supports both discrete palettes and continuous scalars.
- Parameters:
stage_csv (str)
gene (str | None)
x (str)
y (str)
color_by (str | None)
palette (Mapping[str, str] | None)
categories (Sequence[str] | None)
cmap (str)
point_size (float)
alpha (float)
arrow_grid (Tuple[int, int])
arrow_scale (float)
arrow_color (str)
projection_neighbor_choice (str)
expression_scale (str | None)
projection_neighbor_size (int)
min_mass (float)
axis_off (bool)
output_path (str | None)
show (bool)
- Return type:
Axes
- gravity.rank_tf_scores(attention_h5ad, *, groupby='cell_type', method='wilcoxon', key_added='tf_rankings', n_genes=30, output_plot=None, sort_group=None, top_n=30, reuse_h5ad=None)[source]¶
Run differential ranking on TF scores aggregated per cell type.
- Parameters:
attention_h5ad (str) – Path to attention_TF_scores_with_types.h5ad.
sort_group (int | str | None) – Group to sort by (logFC descending, p-value ascending). Accepts index (int) or name (str).
reuse_h5ad (str | None) – If provided with an h5ad already containing uns[key_added], reuse the stored ranking results.
groupby (str)
method (str)
key_added (str)
n_genes (int)
output_plot (str | None)
top_n (int)
- Returns:
all_rankings, top_rankings – DataFrame of all per-group rankings, and (if sort_group is given) the Top-N table for the selected group.
- Return type:
Tuple[DataFrame, DataFrame | None]
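The sort order described for sort_group (logFC descending, then p-value ascending) can be illustrated on plain dictionaries. The function name top_tfs and the field names group, logfoldchange, and pval are hypothetical stand-ins and are not taken from the library's output schema.

```python
from typing import Dict, List


def top_tfs(rows: List[Dict], group: str, top_n: int = 30) -> List[Dict]:
    """Select one group's rows, sort them, and keep the first top_n."""
    selected = [r for r in rows if r["group"] == group]
    # Primary key: larger log fold change first; tie-break: smaller p-value first.
    selected.sort(key=lambda r: (-r["logfoldchange"], r["pval"]))
    return selected[:top_n]
```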
- gravity.compute_batc(adata, cluster_edges, *, cluster_key='clusters', embedding_key='umap', velocity_key='velocity', use_bins=True, n_bins=60, min_per_bin=5, n_samples=800, store_in_adata=True, progress=False)[source]¶
Compute the Branching-aware Trajectory Consistency (BATC) score.
BATC evaluates RNA velocity fields on predefined lineage graphs by fitting smooth principal curves between successive clusters and comparing their tangents with per-cell velocity vectors. For each directed edge \(A \rightarrow B\), a curve \(\gamma_{A \to B}(u)\) is fitted on the 2D embedding of cells belonging to \(A \cup B\). Every cell \(c\) on the edge is projected onto the curve and the cosine similarity between the local tangent \(\tau_c\) and its velocity \(v_c\) is recorded:
\[
\mathrm{BATC}_{A \to B} = \frac{1}{|I_{A \cup B}|} \sum_{c \in I_{A \cup B}} \frac{v_c \cdot \tau_c}{\lVert v_c \rVert \, \lVert \tau_c \rVert}.
\]
To handle branching, for each source cluster \(A\) with outgoing targets \(B_1, \ldots, B_m\) we compute these cosine scores for every outgoing edge and, for each cell \(c \in I_A\), keep the best matching branch:
\[
b_c^{(A)} = \max_{B \in \mathrm{Out}(A)} \frac{v_c \cdot \tau_c^{(A \to B)}}{\lVert v_c \rVert \, \lVert \tau_c^{(A \to B)} \rVert}.
\]
The final dataset-level BATC is the cell-weighted mean of these branch-aware scores over all source clusters:
\[
\mathrm{BATC}_{\mathrm{overall}} = \frac{1}{\sum_A |I_A|} \sum_A \sum_{c \in I_A} b_c^{(A)}.
\]
Zero-norm vectors are mapped to nan scores, and all averages use numpy.nanmean for robustness.
- adata:
Annotated data matrix containing the embedding and velocity arrays.
- cluster_edges:
Directed edges describing permitted transitions between clusters. Nodes are cast to strings to match adata.obs[cluster_key].
- cluster_key:
Column in adata.obs holding cluster labels. Defaults to 'clusters'.
- embedding_key:
Key identifying the 2D embedding. Defaults to 'umap', which is resolved to adata.obsm['X_umap'] when present. Any other key is resolved in the same fashion (X_<key> or direct entry).
- velocity_key:
Key identifying the velocity representation (must align with the embedding). With the default 'velocity' the function looks for adata.obsm['velocity_umap'] when embedding_key='umap'.
- use_bins, n_bins, min_per_bin:
Controls for the skeletonisation step when fitting the principal curves.
- n_samples:
Number of evaluation points used to project cells onto each curve.
- store_in_adata:
If True (default), write per-cell and aggregate scores back to the adata.obs / adata.uns containers.
- progress:
Whether to display a progress bar over edges (requires tqdm).
- float
The BATC overall score (best-per-cell cosine mean).
- Parameters:
cluster_edges (Sequence[Tuple[str, str]])
cluster_key (str)
embedding_key (str)
velocity_key (str)
use_bins (bool)
n_bins (int)
min_per_bin (int)
n_samples (int)
store_in_adata (bool)
progress (bool)
- Return type:
float
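The scoring half of BATC (cosine per edge, best branch per cell, nan-aware mean) can be sketched without the curve-fitting step. In this illustration the per-cell tangents are assumed to be given already. The helper names cosine, nanmean, branch_aware_scores, and batc_overall are hypothetical and are not the library's internals.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vec = Tuple[float, float]


def cosine(v: Vec, t: Vec) -> float:
    """Cosine similarity; zero-norm inputs map to nan."""
    nv, nt = math.hypot(*v), math.hypot(*t)
    if nv == 0.0 or nt == 0.0:
        return math.nan
    return (v[0] * t[0] + v[1] * t[1]) / (nv * nt)


def nanmean(values: Sequence[float]) -> float:
    """Mean that ignores nan entries (stdlib stand-in for numpy.nanmean)."""
    kept = [x for x in values if not math.isnan(x)]
    return sum(kept) / len(kept) if kept else math.nan


def branch_aware_scores(
    velocities: Sequence[Vec],
    tangents_by_branch: Dict[str, Sequence[Vec]],
) -> List[float]:
    """b_c for the cells of one source cluster: best cosine over its branches."""
    scores = []
    for c, v in enumerate(velocities):
        per_branch = [cosine(v, tangents[c]) for tangents in tangents_by_branch.values()]
        valid = [s for s in per_branch if not math.isnan(s)]
        scores.append(max(valid) if valid else math.nan)
    return scores


def batc_overall(per_cluster_scores: Dict[str, List[float]]) -> float:
    """Cell-weighted mean: pool every cell's b_c across source clusters."""
    pooled = [s for scores in per_cluster_scores.values() for s in scores]
    return nanmean(pooled)
```

Pooling the per-cell scores before averaging is what makes the mean cell-weighted: a cluster with more cells contributes proportionally more terms.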
- gravity.set_verbose(level)[source]¶
Set global verbosity level (higher means more logs).
- Parameters:
level (int)
- Return type:
None
- gravity.log_verbose(message, *, level=1)[source]¶
Print message when the verbosity is at least level.
- Parameters:
message (Any)
level (int)
- Return type:
None
Pipeline configuration and execution for the GRAVITY workflow.
This module exposes a high-level configuration dataclass and a single entry-point function that coordinates the complete GRAVITY pipeline:
preprocess long-format counts into a wide table used by subsequent stages,
train the cell-wise stage (stage 1) and export attention matrices if enabled,
estimate future positions from the stage-1 outputs,
train the gene-wise stage (stage 2),
optionally render velocity visualizations at the cell and gene level.
Multi-GPU and distributed execution are handled by PyTorch Lightning inside each training stage; this module coordinates preprocessing, stage transitions, future projection, and optional plotting.
- class gravity.pipeline.PipelineConfig(raw_counts, workdir='gravity_outputs', prior_network='./prior_data/nichenet_mouse.zip', gene_subset=None, gene_order_path=None, batch_size=16, n_pos_neighbors=30, n_neg_neighbors=10, stage1_epochs=6, stage2_epochs=6, stage1_lr=1e-06, stage2_lr=0.0001, stage1_pretrained_checkpoint=None, stage2_pretrained_checkpoint=None, embedding_size=16, model_dimension=16, ffn_dimension=16, val_fraction_stage1=0.0, val_fraction_stage2=0.0, accelerator='auto', devices=None, num_workers=8, strategy=None, precision=None, gradient_clip_val=None, future_tau=0.5, log_every_n_steps=50, progress_bar=True, make_plot=False, plot_gene=None, plot_color='clusters', plot_genes=None, arrow_grid=(20, 20), arrow_scale=1.0, middle_csv_name='combine.csv', stage1_csv_name='stage1.csv', stage2_csv_name='stage2.csv', future_positions_name='future_positions.npy', stage1_checkpoint_name='stage1.ckpt', stage2_checkpoint_name='stage2.ckpt')[source]¶
Configuration for the full GRAVITY pipeline.
- Parameters:
raw_counts (str) – Path to the input long-format CSV (must include at least cellID, gene_name, unsplice, splice, embedding1, embedding2).
workdir (str) – Output directory where intermediate and final artifacts are written.
prior_network (str | None) – Path to the prior TF–target network archive used by GRAVITY.
gene_subset (Sequence[str] | None) – Optional list of genes to restrict training and evaluation.
gene_order_path (str | None) – Optional newline-delimited gene list that fixes gene-index order. Use this when running pretrained/reference checkpoints; model weights and attention matrices are aligned by gene position, not only by gene name.
batch_size (int) – Mini-batch size used by both training stages.
stage1_epochs (int) – Number of training epochs for the cell-wise stage (stage 1).
stage2_epochs (int) – Number of training epochs for the gene-wise stage (stage 2).
stage1_lr (float) – Learning rate for the cell-wise stage (stage 1). The unsupervised cell-wise stage and contrastive gene-wise stage are somewhat learning-rate sensitive; for reference-style runs we recommend keeping stage1_lr below 1e-5 and tuning stage2_lr between 1e-5 and 1e-3.
stage2_lr (float) – Learning rate for the gene-wise stage (stage 2). The recommendation given under stage1_lr applies.
stage1_pretrained_checkpoint (str | None) – Optional checkpoint used for inference/export instead of training stage 1. Use this for checkpoint-based reference reproduction.
stage2_pretrained_checkpoint (str | None) – Optional checkpoint used for inference/export instead of training stage 2. Use this for checkpoint-based reference reproduction.
val_fraction_stage1 (float) – Fraction of data reserved for validation in each stage.
val_fraction_stage2 (float) – Fraction of data reserved for validation in each stage.
accelerator (str) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
devices (int | Sequence[int] | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
strategy (str | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
precision (int | str | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
gradient_clip_val (float | None) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
num_workers (int) – Forwarded to PyTorch Lightning trainers to control device placement, distribution strategy and numerical precision.
future_tau (float) – Scaling factor governing the radius used in future-neighbor search.
log_every_n_steps (int) – Logging and progress display controls.
progress_bar (bool) – Logging and progress display controls.
make_plot (bool) – Plotting options for optional velocity visualization.
plot_gene (str | None) – Plotting options for optional velocity visualization.
plot_color (str | None) – Plotting options for optional velocity visualization.
plot_genes (str | Sequence[str] | None) – Plotting options for optional velocity visualization.
arrow_grid (Tuple[int, int]) – Plotting options for optional velocity visualization.
arrow_scale (float) – Plotting options for optional velocity visualization.
middle_csv_name (str) – Filenames for artifacts written under workdir.
stage*_csv_name – Filenames for artifacts written under workdir.
future_positions_name (str) – Filenames for artifacts written under workdir.
stage*_checkpoint_name – Filenames for artifacts written under workdir.
n_pos_neighbors (int)
n_neg_neighbors (int)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
stage1_csv_name (str)
stage2_csv_name (str)
stage1_checkpoint_name (str)
stage2_checkpoint_name (str)
- raw_counts: str¶
- workdir: str = 'gravity_outputs'¶
- prior_network: str | None = './prior_data/nichenet_mouse.zip'¶
- gene_subset: Sequence[str] | None = None¶
- gene_order_path: str | None = None¶
- batch_size: int = 16¶
- n_pos_neighbors: int = 30¶
- n_neg_neighbors: int = 10¶
- stage1_epochs: int = 6¶
- stage2_epochs: int = 6¶
- stage1_lr: float = 1e-06¶
- stage2_lr: float = 0.0001¶
- stage1_pretrained_checkpoint: str | None = None¶
- stage2_pretrained_checkpoint: str | None = None¶
- embedding_size: int = 16¶
- model_dimension: int = 16¶
- ffn_dimension: int = 16¶
- val_fraction_stage1: float = 0.0¶
- val_fraction_stage2: float = 0.0¶
- accelerator: str = 'auto'¶
- devices: int | Sequence[int] | None = None¶
- num_workers: int = 8¶
- strategy: str | None = None¶
- precision: int | str | None = None¶
- gradient_clip_val: float | None = None¶
- future_tau: float = 0.5¶
- log_every_n_steps: int = 50¶
- progress_bar: bool = True¶
- make_plot: bool = False¶
- plot_gene: str | None = None¶
- plot_color: str | None = 'clusters'¶
- plot_genes: str | Sequence[str] | None = None¶
- arrow_grid: Tuple[int, int] = (20, 20)¶
- arrow_scale: float = 1.0¶
- middle_csv_name: str = 'combine.csv'¶
- stage1_csv_name: str = 'stage1.csv'¶
- stage2_csv_name: str = 'stage2.csv'¶
- future_positions_name: str = 'future_positions.npy'¶
- stage1_checkpoint_name: str = 'stage1.ckpt'¶
- stage2_checkpoint_name: str = 'stage2.ckpt'¶
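The learning-rate guidance in the parameter descriptions (stage1_lr below 1e-5, stage2_lr between 1e-5 and 1e-3) can be checked before launching a run. The helper below is a hypothetical convenience and not part of GRAVITY. It only reports deviations from that recommendation and does not reject any value.

```python
from typing import List


def check_learning_rates(stage1_lr: float, stage2_lr: float) -> List[str]:
    """Return human-readable notes for values outside the recommended ranges."""
    notes = []
    if not stage1_lr < 1e-5:
        notes.append(f"stage1_lr={stage1_lr:g} is not below the recommended 1e-5")
    if not (1e-5 <= stage2_lr <= 1e-3):
        notes.append(f"stage2_lr={stage2_lr:g} is outside the recommended 1e-5 to 1e-3")
    return notes
```

With the documented defaults (stage1_lr=1e-06, stage2_lr=0.0001) the returned list is empty.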
- gravity.pipeline.run_pipeline(config)[source]¶
Execute preprocessing, two training stages, future projection, and optional plotting.
- Parameters:
config (PipelineConfig) – PipelineConfig instance that specifies inputs, training and plotting options, and output filenames.
- Returns:
Dict[str, pathlib.Path] – Mapping of artifact names to absolute paths under workdir. Keys include middle_csv, stage1_csv, stage1_checkpoint, attention_dir, future_positions, stage2_csv, and stage2_checkpoint. If plotting is enabled, additional keys may be present for generated figures.
- Return type:
Dict[str, Path]
Notes
Multi-GPU controls in config are forwarded to PyTorch Lightning inside the stage trainers. Preprocessing, future projection and plotting always run in the main process. When using DDP with strategy='ddp' (spawn), child processes may re-import and execute the entry script; to avoid duplicated orchestration, this function performs an early return on non-zero ranks.
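The early return on non-zero ranks can be pictured with a small guard that reads the rank from environment variables. How GRAVITY actually detects the rank is not stated here; checking RANK and LOCAL_RANK is an assumption based on the variables commonly set by distributed launchers, and the function names is_rank_zero and orchestrate are hypothetical.

```python
import os
from typing import Callable, List, Mapping, Optional


def is_rank_zero(env: Optional[Mapping[str, str]] = None) -> bool:
    """True when no rank variable is set or every set rank equals 0."""
    env = os.environ if env is None else env
    for name in ("RANK", "LOCAL_RANK"):
        value = env.get(name)
        if value is not None and int(value) != 0:
            return False
    return True


def orchestrate(steps: List[Callable[[], object]],
                env: Optional[Mapping[str, str]] = None) -> Optional[list]:
    """Run orchestration steps only on rank zero; other ranks return early."""
    if not is_rank_zero(env):
        return None
    return [step() for step in steps]
```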
Data preprocessing utilities for GRAVITY training stages.
- gravity.data.preprocessing.preprocess_counts(input_file, output_csv, *, gene_order=None)[source]¶
Prepare the cell-wise training table from a long single-cell CSV.
- Parameters:
input_file (str) – Path to the raw long-format CSV with columns including cellID, gene_name, unsplice, splice, embedding1, embedding2.
output_csv (str) – Destination CSV containing one row per cell with serialized gene tuples.
gene_order (Sequence[str] | None) – Optional ordered gene list. When provided, matching gene columns are written first in this order so downstream checkpoints see the intended gene-index layout.
- Returns:
pathlib.Path – The path to the generated intermediate CSV.
- Return type:
Path
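The long-to-wide reshaping can be sketched with the standard library alone. The serialization of each gene cell as the string form of an (unsplice, splice) tuple is an assumption made for illustration, as is the function name long_to_wide. The real intermediate CSV may encode tuples differently.

```python
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


def long_to_wide(
    rows: Iterable[Dict],
    gene_order: Optional[Sequence[str]] = None,
) -> Tuple[List[str], List[list]]:
    """Collapse long-format rows into one row per cell.

    Genes listed in gene_order come first, in that order; any remaining
    genes follow in first-seen order.
    """
    cells: Dict[str, Dict] = {}
    genes_seen: List[str] = []
    for r in rows:
        cell = cells.setdefault(
            r["cellID"],
            {"cellID": r["cellID"],
             "embedding1": r["embedding1"],
             "embedding2": r["embedding2"]},
        )
        # Assumed serialization: string form of the (unsplice, splice) tuple.
        cell[r["gene_name"]] = str((float(r["unsplice"]), float(r["splice"])))
        if r["gene_name"] not in genes_seen:
            genes_seen.append(r["gene_name"])
    ordered = [g for g in (gene_order or []) if g in genes_seen]
    ordered += [g for g in genes_seen if g not in ordered]
    header = ["cellID", "embedding1", "embedding2"] + ordered
    table = [[cell.get(col, "") for col in header] for cell in cells.values()]
    return header, table
```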
- gravity.data.preprocessing.load_cell_stage_dataset(middle_file, *, prior_path='./prior_data/nichenet_mouse.zip', gene_list=None, n_pos_neighbors=30, n_neg_neighbors=10)[source]¶
Instantiate the PyTorch dataset used for the cell-wise stage.
- Parameters:
middle_file (str)
prior_path (str)
gene_list (Sequence[str] | None)
- Return type:
CustomDataset
- gravity.data.preprocessing.load_gene_stage_dataset(middle_file, *, prior_path='./prior_data/nichenet_mouse.zip', future_positions='./final_positions_with_index_yixian.npy', gene_list=None)[source]¶
Instantiate the PyTorch dataset used for the gene-wise refinement stage.
- Parameters:
middle_file (str)
prior_path (str)
future_positions (str)
gene_list (Sequence[str] | None)
- Return type:
CustomDatasetGeneWise
- gravity.data.preprocessing.load_gene_order(path)[source]¶
Load a newline-delimited gene order file.
The order is part of the model contract for pretrained checkpoints because attention and solver tensors are indexed by gene position.
- Parameters:
path (str)
- Return type:
list[str]
- gravity.data.preprocessing.resolve_gene_order(gene_subset=None, gene_order_path=None)[source]¶
Resolve the effective gene list, preserving checkpoint-compatible order.
- Parameters:
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
- Return type:
list[str] | None
- gravity.data.preprocessing.assert_gene_order_matches(path, expected, *, label='genes.txt')[source]¶
Raise if an existing gene-order file does not match expected order.
- Parameters:
path (Path)
expected (Sequence[str])
label (str)
- Return type:
None
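Because checkpoints are aligned by gene position, it helps to see what loading and checking a gene-order file might look like. The sketch below uses hypothetical helper names (read_gene_order, check_gene_order) and raises ValueError on a mismatch; the exception type used by the library is not documented here.

```python
import tempfile
from pathlib import Path
from typing import List, Sequence


def read_gene_order(path) -> List[str]:
    """Read a newline-delimited gene list, skipping blank lines."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def check_gene_order(path, expected: Sequence[str]) -> None:
    """Raise ValueError if the file's order differs from the expected order."""
    found = read_gene_order(path)
    if found == list(expected):
        return
    for i, (a, b) in enumerate(zip(found, expected)):
        if a != b:
            raise ValueError(f"gene order differs at index {i}: {a!r} != {b!r}")
    raise ValueError(f"gene count differs: {len(found)} != {len(expected)}")


# Small demonstration on a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    genes_file = Path(tmp) / "genes.txt"
    genes_file.write_text("Sox2\nPax6\n\nGata1\n", encoding="utf-8")
    loaded = read_gene_order(genes_file)
    check_gene_order(genes_file, ["Sox2", "Pax6", "Gata1"])  # same order: passes
    try:
        check_gene_order(genes_file, ["Pax6", "Sox2", "Gata1"])
        mismatch_detected = False
    except ValueError:
        mismatch_detected = True
```

Note that the same set of genes in a different order counts as a mismatch, which is the point of an index-aligned contract.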
- gravity.data.preprocessing.export_intermediate_from_h5ad(input_h5ad, output_csv, *, retain_genes=None, min_shared_counts=20, n_top_genes=2000, n_pcs=30, n_neighbors=30, embed_key='X_umap', celltype_key='cell_type', overwrite=False, normalized=False)[source]¶
Preprocess AnnData for GRAVITY and export a cellDancer-style CSV.
This function performs only the preprocessing required to create GRAVITY’s user-facing long-format count table. It does not run RNA velocity inference, does not plot, and does not save a processed .h5ad.
Pipeline¶
Read the input AnnData (.h5ad).
Optionally force-keep user-specified genes through filtering/HVG selection.
Normalize and filter genes with the local preprocessing settings.
Compute first-order moments with scv.pp.moments (produces ‘Mu’/’Ms’).
Export a cellDancer-style CSV via the local adata_to_df_with_embed function, including Mu/Ms, the chosen 2D embedding, and cell-type labels.
- Parameters:
input_h5ad (str) – Path to an AnnData file that includes ‘spliced’ and ‘unspliced’ layers.
output_csv (str) – Destination CSV file path for the cellDancer-style long-format table.
retain_genes (Sequence[str] | None) – Genes that must be retained during filtering/HVG selection (if present).
min_shared_counts (int) – Minimum shared counts across cells for gene filtering.
n_top_genes (int) – Number of highly variable genes to retain (in addition to retain_genes).
n_pcs (int) – Number of principal components used by scv.pp.moments.
n_neighbors (int) – Neighborhood size used by scv.pp.moments.
embed_key (str) – Key in adata.obsm for a 2D embedding (e.g., “X_umap”).
celltype_key (str) – Column in adata.obs that holds cell-type labels (e.g., “cell_type” or “celltype”).
overwrite (bool) – If False and output_csv exists, skip work and return the existing path.
normalized (bool)
- Returns:
pathlib.Path – The path to the generated CSV.
- Raises:
KeyError – If embed_key is not found in adata.obsm.
RuntimeError – If required layers (‘spliced’ and ‘unspliced’) are missing.
ImportError – If scanpy or scvelo are missing.
- Return type:
Path
- gravity.data.preprocessing.adata_to_df_with_embed(adata, us_para=('Mu', 'Ms'), cell_type_para='celltype', embed_para='X_umap', save_path='cell_type_u_s_sample_df.csv', gene_list=None)[source]¶
Convert an AnnData object to a long CSV with per-gene/per-cell rows and 2D embedding.
The resulting CSV contains, for every (gene, cell) pair:
gene_name, unsplice, splice
cellID, clusters (cell-type), embedding1, embedding2
Notes
The two matrices are taken from adata.layers[us_para[0]] (unspliced) and adata.layers[us_para[1]] (spliced). By default, that is [‘Mu’, ‘Ms’].
The 2D embedding is read from adata.obsm[embed_para] (e.g., ‘X_umap’).
This function writes the CSV incrementally (one gene at a time) to keep memory usage manageable for large datasets, then appends cell metadata.
- Parameters:
adata – An anndata.AnnData object containing layers and embedding.
us_para (Sequence[str]) – Names of the two layers for unspliced and spliced moments/counts, respectively; default (‘Mu’, ‘Ms’).
cell_type_para (str) – Column name in adata.obs that holds cell-type labels (default ‘celltype’).
embed_para (str) – Key in adata.obsm for 2D embedding (default ‘X_umap’).
save_path (str) – Destination CSV file path.
gene_list (Sequence[str] | None) – Specific genes to export. If None, use all genes (adata.var.index).
- Returns:
pandas.DataFrame – The final DataFrame that was saved to save_path.
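The gene-at-a-time writing strategy can be illustrated with the csv module and nested lists in place of AnnData layers. This is a simplified sketch: write_long_csv is a hypothetical name, the column order is an assumption, and cell metadata is written inline with each row rather than appended afterwards as the notes above describe.

```python
import csv
import io
from typing import Sequence


def write_long_csv(handle, genes: Sequence[str], cell_ids: Sequence[str],
                   unspliced, spliced, clusters: Sequence[str], embedding) -> None:
    """Write one row per (gene, cell) pair.

    unspliced / spliced are cells x genes nested lists; embedding is cells x 2.
    """
    writer = csv.writer(handle)
    writer.writerow(["gene_name", "unsplice", "splice",
                     "cellID", "clusters", "embedding1", "embedding2"])
    for g, gene in enumerate(genes):  # one gene at a time keeps memory flat
        for c, cell in enumerate(cell_ids):
            writer.writerow([gene, unspliced[c][g], spliced[c][g], cell,
                             clusters[c], embedding[c][0], embedding[c][1]])


# Tiny demonstration: 2 cells x 2 genes written to an in-memory buffer.
buffer = io.StringIO()
write_long_csv(
    buffer,
    genes=["A", "B"],
    cell_ids=["c1", "c2"],
    unspliced=[[1, 3], [5, 7]],
    spliced=[[2, 4], [6, 8]],
    clusters=["t1", "t2"],
    embedding=[[0.1, 0.2], [0.3, 0.4]],
)
```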
High-level helpers to train the GRAVITY cell-wise stage.
The functions in this module prepare datasets, configure PyTorch Lightning trainers, and export artefacts (CSV predictions, TF scores, priors) expected by downstream steps. They do not change the model’s inner computations.
- class gravity.train.cell_stage.CellStageConfig(raw_counts, middle_csv, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage1_csv='stage1.csv', checkpoint_name='stage1.ckpt', pretrained_checkpoint=None, attention_dir='attentions', gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, attention_topk=64, attention_output=True, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=1e-06, embedding_size=16, model_dimension=16, ffn_dimension=16, n_pos_neighbors=30, n_neg_neighbors=10)[source]¶
Configuration bundle for the cell-wise training stage.
See also the top-level gravity.pipeline.PipelineConfig for how device/distribution options propagate. gene_order_path fixes the checkpoint-compatible gene index order when using pretrained/reference weights.
- Parameters:
raw_counts (str)
middle_csv (str)
prior_network (str | None)
output_dir (str)
stage1_csv (str)
checkpoint_name (str)
pretrained_checkpoint (str | None)
attention_dir (str)
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
batch_size (int)
epochs (int)
accelerator (str)
devices (int | Sequence[int] | None)
num_workers (int)
val_fraction (float)
attention_topk (int)
attention_output (bool)
precision (int | str | None)
gradient_clip_val (float | None)
strategy (str | None)
seed (int)
log_every_n_steps (int)
progress_bar (bool)
learning_rate (float)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
n_pos_neighbors (int)
n_neg_neighbors (int)
- raw_counts: str¶
- middle_csv: str¶
- prior_network: str | None = './prior_data/nichenet_mouse.zip'¶
- output_dir: str = 'gravity_outputs'¶
- stage1_csv: str = 'stage1.csv'¶
- checkpoint_name: str = 'stage1.ckpt'¶
- pretrained_checkpoint: str | None = None¶
- attention_dir: str = 'attentions'¶
- gene_subset: Sequence[str] | None = None¶
- gene_order_path: str | None = None¶
- batch_size: int = 32¶
- epochs: int = 6¶
- accelerator: str = 'auto'¶
- devices: int | Sequence[int] | None = None¶
- num_workers: int = 0¶
- val_fraction: float = 0.0¶
- attention_topk: int = 64¶
- attention_output: bool = True¶
- precision: int | str | None = None¶
- gradient_clip_val: float | None = None¶
- strategy: str | None = None¶
- seed: int = 42¶
- log_every_n_steps: int = 50¶
- progress_bar: bool = True¶
- learning_rate: float = 1e-06¶
- embedding_size: int = 16¶
- model_dimension: int = 16¶
- ffn_dimension: int = 16¶
- n_pos_neighbors: int = 30¶
- n_neg_neighbors: int = 10¶
- gravity.train.cell_stage.train_cell_stage(config)[source]¶
Run the GRAVITY cell-wise stage and export key artefacts.
- Returns:
Dict[str, Path] – Paths to stage1_csv, checkpoint, attention outputs (optional), and the generated gene files used by later steps.
- Parameters:
config (CellStageConfig)
- Return type:
Dict[str, Path]
High-level helpers to refine GRAVITY models in the gene-wise stage.
This module loads the stage-1 checkpoint, freezes most parameters, and enables fine-tuning of a restricted set of solver layers. Artefacts are exported in the same format used by the cell-wise stage to ease downstream consumption.
- class gravity.train.gene_stage.GeneStageConfig(raw_counts, middle_csv, stage1_checkpoint, future_positions, prior_network='./prior_data/nichenet_mouse.zip', output_dir='gravity_outputs', stage2_csv='stage2.csv', checkpoint_name='stage2.ckpt', pretrained_checkpoint=None, gene_subset=None, gene_order_path=None, batch_size=32, epochs=6, accelerator='auto', devices=None, num_workers=0, val_fraction=0.0, precision=None, gradient_clip_val=None, strategy=None, seed=42, log_every_n_steps=50, progress_bar=True, learning_rate=0.0001, embedding_size=16, model_dimension=16, ffn_dimension=16)[source]¶
Configuration bundle for the gene-wise fine-tuning stage.
Device and distribution arguments mirror those in gravity.pipeline.PipelineConfig. gene_order_path should point to the same gene list used by the stage-1 checkpoint when reproducing pretrained/reference runs.
- Parameters:
raw_counts (str)
middle_csv (str)
stage1_checkpoint (str)
future_positions (str)
prior_network (str | None)
output_dir (str)
stage2_csv (str)
checkpoint_name (str)
pretrained_checkpoint (str | None)
gene_subset (Sequence[str] | None)
gene_order_path (str | None)
batch_size (int)
epochs (int)
accelerator (str)
devices (int | Sequence[int] | None)
num_workers (int)
val_fraction (float)
precision (int | str | None)
gradient_clip_val (float | None)
strategy (str | None)
seed (int)
log_every_n_steps (int)
progress_bar (bool)
learning_rate (float)
embedding_size (int)
model_dimension (int)
ffn_dimension (int)
- raw_counts: str¶
- middle_csv: str¶
- stage1_checkpoint: str¶
- future_positions: str¶
- prior_network: str | None = './prior_data/nichenet_mouse.zip'¶
- output_dir: str = 'gravity_outputs'¶
- stage2_csv: str = 'stage2.csv'¶
- checkpoint_name: str = 'stage2.ckpt'¶
- pretrained_checkpoint: str | None = None¶
- gene_subset: Sequence[str] | None = None¶
- gene_order_path: str | None = None¶
- batch_size: int = 32¶
- epochs: int = 6¶
- accelerator: str = 'auto'¶
- devices: int | Sequence[int] | None = None¶
- num_workers: int = 0¶
- val_fraction: float = 0.0¶
- precision: int | str | None = None¶
- gradient_clip_val: float | None = None¶
- strategy: str | None = None¶
- seed: int = 42¶
- log_every_n_steps: int = 50¶
- progress_bar: bool = True¶
- learning_rate: float = 0.0001¶
- embedding_size: int = 16¶
- model_dimension: int = 16¶
- ffn_dimension: int = 16¶
- gravity.train.gene_stage.train_gene_stage(config)[source]¶
Run the GRAVITY gene-wise refinement stage.
- Returns:
Dict[str, Path] – Paths to stage2_csv, checkpoint, and the gene/prior files.
- Parameters:
config (GeneStageConfig)
- Return type:
Dict[str, Path]
Utility to project GRAVITY velocities towards future positions.
- gravity.tools.future.estimate_future_positions(stage1_csv, output_path, *, tau=0.5, show_plot=False, plot_path=None, projection_neighbor_choice='embedding', projection_neighbor_size=200, expression_scale='power10')[source]¶
Estimate future embeddings using the learned cell-wise velocities.
- Parameters:
stage1_csv (str) – Output CSV from the cell-wise stage (stage1).
output_path (str) – Destination .npy file to store the nearest-neighbour anchors.
tau (float) – Scaling factor that shrinks the velocity vectors when forming the search radius.
show_plot (bool) – Whether to display a matplotlib window with the quiver plot.
plot_path (str | None) – Optional path to save the figure instead of (or in addition to) showing it.
projection_neighbor_choice (str)
projection_neighbor_size (int)
expression_scale (str | None)
- Returns:
final_positions, cell_df – Tuple of the n × 3 array with neighbour coordinates + indices, and the augmented dataframe returned by compute_cell_velocity_().
- Return type:
Tuple[ndarray, DataFrame]