Source code for gravity.tools.future

"""Utility to project GRAVITY velocities towards future positions."""

from __future__ import annotations

from pathlib import Path
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ..utils import log_verbose, resolve_path
from ..velocity import compute_cell_velocity_, extract_from_df

__all__ = [
    "estimate_future_positions",
]


[docs] def estimate_future_positions( stage1_csv: str, output_path: str, *, tau: float = 0.5, show_plot: bool = False, plot_path: Optional[str] = None, projection_neighbor_choice: str = 'embedding', projection_neighbor_size: int = 200, expression_scale: Optional[str] = 'power10', ) -> Tuple[np.ndarray, pd.DataFrame]: """Estimate future embeddings using the learned cell-wise velocities. Parameters ---------- stage1_csv: Output CSV from the cell-wise stage (stage1). output_path: Destination ``.npy`` file to store the nearest-neighbour anchors. tau: Scaling factor that shrinks the velocity vectors when forming the search radius. show_plot: Whether to display a matplotlib window with the quiver plot. plot_path: Optional path to save the figure instead of (or in addition to) showing it. Returns ------- final_positions, cell_df: Tuple of the ``n × 3`` array with neighbour coordinates + indices, and the augmented dataframe returned by :func:`compute_cell_velocity_`. """ stage1_path = resolve_path(stage1_csv) out_path = Path(output_path).expanduser().resolve() out_path.parent.mkdir(parents=True, exist_ok=True) stage1_df = pd.read_csv(stage1_path, index_col=0) existing_future = None if out_path.exists(): try: existing_future = np.load(str(out_path)) expected_cells = stage1_df['cellIndex'].nunique() if existing_future.shape[0] == expected_cells: log_verbose(f"[gravity] found existing future positions: {out_path}; skip.", level=1) return existing_future, stage1_df else: log_verbose( f"[gravity] existing future positions mismatch dataset (rows: {existing_future.shape[0]} vs {expected_cells}); recomputing.", level=1, ) existing_future = None except Exception as exc: log_verbose(f"[gravity] failed to load existing future positions ({exc}); recomputing.", level=1) existing_future = None log_verbose(f"[gravity] computing projected velocities from {stage1_path}", level=1) cell_df, velocity_embedding = compute_cell_velocity_( stage1_df, projection_neighbor_choice=projection_neighbor_choice, expression_scale=expression_scale, projection_neighbor_size=projection_neighbor_size, speed_up=None, ) embeddings = extract_from_df(cell_df, ['embedding1', 'embedding2'], None) if embeddings.ndim == 1: embeddings = embeddings.reshape(-1, 2) directions = velocity_embedding new_positions = embeddings + directions radius = np.linalg.norm(directions, axis=1) * tau final_positions = np.zeros((new_positions.shape[0], 3), dtype=float) hits = 0 for idx in range(new_positions.shape[0]): distances = np.linalg.norm(embeddings - new_positions[idx], axis=1) neighbours = np.where(distances < radius[idx])[0] if neighbours.size == 0: final_positions[idx, :2] = embeddings[idx] final_positions[idx, 2] = idx continue hits += 1 closest = neighbours[np.argmin(distances[neighbours])] final_positions[idx, :2] = embeddings[closest] final_positions[idx, 2] = closest log_verbose(f"[gravity] neighbours found within radius for {hits} cells", level=1) np.save(str(out_path), final_positions) log_verbose(f"[gravity] saved neighbour anchors to {out_path}", level=2) if show_plot or plot_path is not None: fig, ax = plt.subplots(figsize=(10, 10)) ax.scatter( embeddings[:, 0], embeddings[:, 1], s=8, color="#4A90E2", alpha=0.45, label="Current positions", edgecolor="none", ) ax.quiver( embeddings[:, 0], embeddings[:, 1], directions[:, 0], directions[:, 1], angles='xy', scale_units='xy', scale=1, color="#FF5A5F", alpha=0.6, linewidth=0.3, label="Velocity", ) ax.scatter( new_positions[:, 0], new_positions[:, 1], s=8, color="#37B26C", alpha=0.45, label="Projected positions", edgecolor="none", ) ax.scatter( final_positions[:, 0], final_positions[:, 1], s=8, color="#F7C948", alpha=0.6, label="Anchor neighbors", edgecolor="none", ) ax.set_xlabel('Embedding 1') ax.set_ylabel('Embedding 2') ax.legend(frameon=False, loc='upper right') ax.set_title('Future neighbor projection') ax.grid(True, linewidth=0.3, alpha=0.2) ax.set_aspect('equal', adjustable='datalim') plt.tight_layout() if plot_path is not None: plot_path_resolved = Path(plot_path).expanduser().resolve() plot_path_resolved.parent.mkdir(parents=True, exist_ok=True) fig.savefig(plot_path_resolved, dpi=300, bbox_inches='tight') log_verbose(f"[gravity] saved future-position plot to {plot_path_resolved}", level=2) if show_plot: plt.show() else: plt.close(fig) return final_positions, cell_df