# -*- coding: utf-8 -*-
"""Plotting helpers for GRAVITY velocity visualisations.
Cell-level arrows are generated with the grid-curve procedure used by the
GRAVITY velocity plots: two-end grids, two KNN passes, Gaussian weights,
absolute ``min_mass`` threshold, Bezier curves, and uniform arrow length.
Entrypoints
-----------
- :func:`plot_velocity_cell` – cell-level embedding plot with velocity arrows.
- :func:`plot_velocity_gene` – gene-level phase portrait with projected arrows.
- :func:`scatter_cell` – lower-level helper used internally.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Sequence, Tuple, Mapping, Dict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.patches import Patch
from matplotlib.colors import ListedColormap, LinearSegmentedColormap, to_hex, hsv_to_rgb
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from sklearn.neighbors import NearestNeighbors
from scipy.stats import norm as normal
import bezier
from ..utils import log_verbose, resolve_path, build_colormap_for_categories, build_colormap, map_colors_auto
from ..velocity import compute_cell_velocity_, sampling_neighbors
__all__ = ["plot_velocity_cell", "plot_velocity_gene", "scatter_cell", "build_colormap_for_categories", "build_colormap"]
# -----------------------------------------------------------------------------
# DataFrame column extraction
# -----------------------------------------------------------------------------
def _extract_from_df(df: pd.DataFrame, attrs: Sequence[str] | str, gene: Optional[str] = None) -> np.ndarray:
"""Return selected columns for a single gene, dropping rows with NaNs."""
if gene is None:
gene = df["gene_name"].iloc[0]
if isinstance(attrs, str):
attrs = [attrs]
sub = df.loc[df["gene_name"] == gene, list(attrs)].dropna()
arr = sub.to_numpy()
if arr.ndim == 2 and arr.shape[1] == 1:
arr = arr[:, 0]
return arr
def _resolve_gene_name(df: pd.DataFrame, gene: Optional[str]) -> str:
"""Resolve a requested gene name against the stage CSV, case-insensitively."""
genes = df["gene_name"].dropna().astype(str).unique().tolist()
if not genes:
raise ValueError("No genes found in stage CSV.")
if gene is None:
return genes[0]
gene = str(gene)
if gene in genes:
return gene
matches = [candidate for candidate in genes if candidate.upper() == gene.upper()]
if len(matches) == 1:
return matches[0]
if matches:
raise ValueError(f"Gene name {gene!r} is ambiguous; matching stage CSV entries: {matches[:10]}")
preview = ", ".join(genes[:10])
raise ValueError(f"Gene {gene!r} was not found in stage CSV. Available genes include: {preview}")
# -----------------------------------------------------------------------------
# Grid-curve implementation for cell-level arrows
# -----------------------------------------------------------------------------
def _find_nn_neighbors(data: np.ndarray, queries: np.ndarray, n_neighbors: int):
nn = NearestNeighbors(n_neighbors=max(1, int(n_neighbors)))
nn.fit(np.asarray(data))
dists, idxs = nn.kneighbors(np.asarray(queries), return_distance=True)
return dists, idxs
def _grid_curve_arrows(ax,
embedding_ds: np.ndarray,
velocity_embedding: np.ndarray,
arrow_grid: Tuple[int, int],
min_mass: float) -> None:
"""Curve-based arrow rendering on a rectangular grid.
Pipeline: two-end grids → KNN smoothing (Gaussian kernel) → absolute mass
threshold → Bezier curve evaluation → normalized tail arrows.
"""
def _calculate_two_end_grid(emb, vel, smooth: float, steps: Tuple[int, int], min_mass: float):
# Grid with slight offset and small padding
grs = []
for dim in range(emb.shape[1]):
m, M = np.min(emb[:, dim]) - 0.2, np.max(emb[:, dim]) - 0.2
m = m - 0.025 * np.abs(M - m)
M = M + 0.025 * np.abs(M - m)
grs.append(np.linspace(m, M, int(steps[dim])))
mesh = np.meshgrid(*grs)
XY = np.vstack([axis.flat for axis in mesh]).T
# Number of neighbors = n/3
k = max(1, int(vel.shape[0] / 3))
# Two-end KNN (head/tail)
d_head, ix_head = _find_nn_neighbors(emb, XY, k)
d_tail, ix_tail = _find_nn_neighbors(emb + vel, XY, k)
# Gaussian kernel (bandwidth = smooth * average grid spacing)
std = float(np.mean([(g[1] - g[0]) for g in grs]))
gw_head = normal.pdf(x=d_head, loc=0, scale=smooth * std)
mass_head = gw_head.sum(1)
gw_tail = normal.pdf(x=d_tail, loc=0, scale=smooth * std)
mass_tail = gw_tail.sum(1)
# Weighted averages
UZ_head = (vel[ix_head] * gw_head[:, :, None]).sum(1) / np.maximum(1, mass_head)[:, None]
UZ_tail = (vel[ix_tail] * gw_tail[:, :, None]).sum(1) / np.maximum(1, mass_tail)[:, None]
# Second KNN pass (after displacement)
d_head2, ix_head2 = _find_nn_neighbors(emb, XY + UZ_head, k)
d_tail2, ix_tail2 = _find_nn_neighbors(emb, XY - UZ_tail, k)
gw_head2 = normal.pdf(x=d_head2, loc=0, scale=smooth * std)
mass_head2 = gw_head2.sum(1)
gw_tail2 = normal.pdf(x=d_tail2, loc=0, scale=smooth * std)
mass_tail2 = gw_tail2.sum(1)
UZ_head2 = (vel[ix_head2] * gw_head2[:, :, None]).sum(1) / np.maximum(1, mass_head2)[:, None]
UZ_tail2 = (vel[ix_tail2] * gw_tail2[:, :, None]).sum(1) / np.maximum(1, mass_tail2)[:, None]
keep = mass_head >= float(min_mass) # Absolute threshold
return (XY[keep], UZ_head[keep], UZ_tail[keep], UZ_head2[keep], UZ_tail2[keep], grs)
curve_smooth = 0.8
XY, UH, UT, UH2, UT2, grs = _calculate_two_end_grid(
embedding_ds, velocity_embedding, smooth=curve_smooth, steps=arrow_grid, min_mass=min_mass
)
n_curves = XY.shape[0]
s_vals = np.linspace(0.0, 1.5, 15)
def _norm_ratio(XY, UT, UH, UT2, UH2, grs, s_vals):
def _seg_len(x, y):
return np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)
max_len = 0.0
for i in range(n_curves):
nodes = np.asfortranarray([
[XY[i,0]-UT[i,0]-UT2[i,0], XY[i,0]-UT[i,0], XY[i,0], XY[i,0]+UH[i,0], XY[i,0]+UH[i,0]+UH2[i,0]],
[XY[i,1]-UT[i,1]-UT2[i,1], XY[i,1]-UT[i,1], XY[i,1], XY[i,1]+UH[i,1], XY[i,1]+UH[i,1]+UH2[i,1]],
])
curve = bezier.Curve(nodes, degree=4)
dots = curve.evaluate_multi(s_vals)
max_len = max(max_len, float(np.sum(_seg_len(dots[0], dots[1]))))
grid_step = (abs(grs[0][1]-grs[0][0]) + abs(grs[1][1]-grs[1][0])) / 2.0
return (grid_step / max_len) if max_len > 0 else 1.0
ratio = _norm_ratio(XY, UT, UH, UT2, UH2, grs, s_vals)
UT, UH, UT2, UH2 = UT * ratio, UH * ratio, UT2 * ratio, UH2 * ratio
# Draw Bezier curves and normalized tail arrows
for i in range(n_curves):
nodes = np.asfortranarray([
[XY[i,0]-UT[i,0]-UT2[i,0], XY[i,0]-UT[i,0], XY[i,0], XY[i,0]+UH[i,0], XY[i,0]+UH[i,0]+UH2[i,0]],
[XY[i,1]-UT[i,1]-UT2[i,1], XY[i,1]-UT[i,1], XY[i,1], XY[i,1]+UH[i,1], XY[i,1]+UH[i,1]+UH2[i,1]],
])
curve = bezier.Curve(nodes, degree=4)
dots = curve.evaluate_multi(s_vals)
ax.plot(dots[0], dots[1], linewidth=0.5, color='black', alpha=1.0)
U = dots[0][-1] - dots[0][-2]
V = dots[1][-1] - dots[1][-2]
N = (U**2 + V**2) ** 0.5 + 1e-12
U1, V1 = (U/N) * 0.5, (V/N) * 0.5 # 固定 0.5
ax.quiver(dots[0][-2], dots[1][-2], U1, V1,
units='xy', angles='xy', scale=1, linewidth=0,
color='black', alpha=1.0, minlength=0, width=0.1)
# -----------------------------------------------------------------------------
# Cell-level plotting API
# -----------------------------------------------------------------------------
def scatter_cell(
ax,
cellDancer_df: pd.DataFrame,
colors=None,
custom_xlim: Optional[Tuple[float, float]] = None,
custom_ylim: Optional[Tuple[float, float]] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
alpha: float = 0.5,
s: float = 200,
gene: Optional[str] = None,
velocity: bool = False,
legend: str = 'off',
colorbar: str = 'on',
min_mass: float = 0.5,
arrow_grid: Tuple[int, int] = (20, 20)
):
"""Render a cell-level scatter plot and optional velocity arrows.
This follows the GRAVITY cell-level plotting convention while accepting
categorical mappings, continuous columns, and fixed color values.
"""
# --- Color handling ---
if isinstance(colors, list):
colors = build_colormap(colors)
if isinstance(colors, dict):
# Cluster mapping
cluster_vals = _extract_from_df(cellDancer_df, 'clusters', gene)
c = np.vectorize(colors.get)(pd.Series(cluster_vals, dtype=str)).tolist()
cmap = ListedColormap(list(colors.values()))
if legend != 'off':
handles = [Patch(facecolor=colors[k], edgecolor="none", label=k) for k in colors]
lgd = ax.legend(handles=handles, bbox_to_anchor=(1.01, 1), loc='upper left')
elif isinstance(colors, str):
# Continuous variable: use the standard 'viridis' colormap
attr = colors
assert gene, "Error! gene is required!"
cmap = 'viridis'
c = _extract_from_df(cellDancer_df, attr, gene)
else:
cmap = None
c = 'Grey'
# --- Scatter ---
emb = _extract_from_df(cellDancer_df, ['embedding1', 'embedding2'], gene)
n_cells = emb.shape[0]
im = ax.scatter(emb[:, 0], emb[:, 1], c=c, cmap=cmap, s=s,
vmin=vmin, vmax=vmax, alpha=alpha, edgecolor="none")
if colorbar == 'on' and isinstance(colors, str):
ax_divider = make_axes_locatable(ax)
cax = ax_divider.append_axes("top", size="5%", pad="-5%")
cbar = plt.colorbar(im, cax=cax, orientation="horizontal", shrink=0.1)
cbar.set_ticks([])
# --- Grid-curve arrows ---
if velocity:
# Only draw for sampled cells with velocity
sample_cells = cellDancer_df['velocity1'][:n_cells].dropna().index
emb_ds = emb[sample_cells]
vel = _extract_from_df(cellDancer_df, ['velocity1', 'velocity2'], gene) # aligned with emb_ds after dropna
_grid_curve_arrows(ax, emb_ds, vel, arrow_grid, min_mass)
if custom_xlim is not None:
ax.set_xlim(*custom_xlim)
if custom_ylim is not None:
ax.set_ylim(*custom_ylim)
return ax
# -----------------------------------------------------------------------------
# High-level entry: plot_velocity_* helpers (embedding uses grid-curve arrows)
# -----------------------------------------------------------------------------
[docs]
def plot_velocity_cell(
stage_csv: str,
*,
gene: Optional[str] = None,
x: str = "splice",
y: str = "unsplice",
color_by: Optional[str] = "clusters",
palette: Optional[Mapping[str, str]] = None,
categories: Optional[Sequence[str]] = None,
cmap: str = "viridis",
point_size: float = 200.0,
alpha: float = 0.5,
arrow_grid: Tuple[int, int] = (20, 20),
arrow_scale: float = 1.0,
arrow_color: str = "k",
projection_neighbor_choice: str = "embedding",
expression_scale: Optional[str] = "power10",
projection_neighbor_size: int = 200,
min_mass: float = 0.5,
axis_off: bool = True,
output_path: Optional[str] = None,
show: bool = False,
) -> plt.Axes:
"""Plot cell-level velocities in the embedding space.
The function computes velocity projections via :func:`gravity.velocity.compute_cell_velocity_`
and overlays grid-curved arrows. Color handling supports both discrete
palettes and continuous scalars.
"""
path = resolve_path(stage_csv)
log_verbose(f"[gravity] plotting velocities from {path}", level=1)
df = pd.read_csv(path)
gene = _resolve_gene_name(df, gene)
fig, ax = plt.subplots(figsize=(20, 20))
# 1) Compute cell-level velocity.
res = compute_cell_velocity_(
df,
projection_neighbor_choice=projection_neighbor_choice,
expression_scale=expression_scale,
projection_neighbor_size=projection_neighbor_size,
)
cell_df = res[0] if isinstance(res, tuple) else res
# 2) Translate color_by to the form required by scatter_cell(colors=...)
# - Discrete column (e.g. 'clusters'): pass a dict (palette or auto)
# - Continuous column ('alpha'/'beta'/'gamma'/'splice'/'unsplice'/'pseudotime'): pass the column name
# - Other discrete column: temporarily map it to 'clusters' in a copy
colors_arg = None
plot_df = cell_df.copy()
if color_by is None:
colors_arg = None
elif color_by in ('alpha', 'beta', 'gamma', 'splice', 'unsplice', 'pseudotime'):
# scatter_cell uses the 'viridis' continuous colormap
colors_arg = color_by
else:
# Treat as discrete labels
if color_by != 'clusters':
# Copy the requested discrete column to 'clusters' (plot-local)
plot_df.loc[:, 'clusters'] = plot_df[color_by].astype(str)
# Use provided palette; otherwise auto-generate one
if palette is not None:
colors_arg = dict(palette)
else:
# Derive categories from the current gene only
gene_mask = (plot_df['gene_name'] == gene)
cats = list(pd.unique(plot_df.loc[gene_mask, 'clusters'].astype(str)))
colors_arg = build_colormap_for_categories(cats)
# 3) Delegate to scatter_cell (grid-curve arrows)
scatter_cell(
ax=ax,
cellDancer_df=plot_df,
colors=colors_arg,
alpha=alpha,
s=point_size,
gene=gene,
velocity=True, # Draw cell-level arrows
legend='off', # Enable if a legend is desired
colorbar='on' if isinstance(colors_arg, str) else 'off',
min_mass=min_mass,
arrow_grid=arrow_grid,
)
ax.set_xlabel("Embedding 1")
ax.set_ylabel("Embedding 2")
ax.set_title("GRAVITY cell velocities (embedding)")
if axis_off:
ax.axis("off")
ax.grid(False)
if output_path is not None:
out = Path(output_path).expanduser().resolve()
out.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out, dpi=300, bbox_inches="tight")
log_verbose(f"[gravity] saved velocity plot to {out}", level=2)
if show:
plt.show()
return ax
[docs]
def plot_velocity_gene(
stage_csv: str,
*,
gene: Optional[str] = None,
x: str = "splice",
y: str = "unsplice",
color_by: Optional[str] = "clusters",
palette: Optional[Mapping[str, str]] = None,
categories: Optional[Sequence[str]] = None,
cmap: str = "viridis",
point_size: float = 200.0,
alpha: float = 0.5,
arrow_grid: Tuple[int, int] = (20, 20),
arrow_scale: float = 1.0,
arrow_color: str = "k",
projection_neighbor_choice: str = "embedding",
expression_scale: Optional[str] = "power10",
projection_neighbor_size: int = 200,
min_mass: float = 0.5,
axis_off: bool = True,
output_path: Optional[str] = None,
show: bool = False,
) -> plt.Axes:
"""Plot gene-level phase portraits with projected velocity arrows.
The arrows connect current to predicted expression for a sampled subset of
cells.
"""
path = resolve_path(stage_csv)
log_verbose(f"[gravity] plotting velocities from {path}", level=1)
df = pd.read_csv(path)
gene = _resolve_gene_name(df, gene)
fig, ax = plt.subplots(figsize=(20, 20))
# ====== Gene-level straight-arrow implementation ======
if x not in {"splice", "unsplice"} or y not in {"splice", "unsplice"}:
raise ValueError("For gene phase portrait mode, x and y must be 'splice' or 'unsplice'.")
gene_df = df[df["gene_name"] == gene].copy()
if gene_df.empty:
raise ValueError(f"Gene {gene!r} resolved but no rows were available for plotting.")
# Colors: discrete → palette/autogenerated; continuous → colormap
mapped, meta = map_colors_auto(
gene_df,
color_by=color_by,
palette=palette,
categories=categories,
cmap_continuous=cmap,
)
scatter_kwargs = dict(s=point_size, alpha=alpha, edgecolor="none")
coords = gene_df[[x, y]].to_numpy()
if isinstance(meta, tuple) and meta[0] == "discrete":
_, handles = meta
ax.scatter(coords[:, 0], coords[:, 1], color=mapped, **scatter_kwargs)
if handles:
ax.legend(handles=handles, title=color_by, loc="best", frameon=False)
elif isinstance(meta, tuple) and meta[0] == "continuous":
_, cm_name = meta
sc = ax.scatter(coords[:, 0], coords[:, 1], c=mapped, cmap=cm_name, **scatter_kwargs)
plt.colorbar(sc, ax=ax, label=color_by)
else:
ax.scatter(coords[:, 0], coords[:, 1], color=mapped, **scatter_kwargs)
# Straight arrows.
u_s = gene_df[["unsplice", "splice", "unsplice_predict", "splice_predict"]].to_numpy()
idx = np.asarray(sampling_neighbors(u_s[:, 0:2], step=arrow_grid, percentile=15))
idx = idx[idx < u_s.shape[0]]
U = u_s[idx, :]
if x == "splice" and y == "unsplice":
P_x, P_y = U[:, 1], U[:, 0]
dX, dY = (U[:, 3] - U[:, 1]), (U[:, 2] - U[:, 0])
elif x == "unsplice" and y == "splice":
P_x, P_y = U[:, 0], U[:, 1]
dX, dY = (U[:, 2] - U[:, 0]), (U[:, 3] - U[:, 1])
else:
P_x, P_y = U[:, 1], U[:, 0]
dX, dY = (U[:, 3] - U[:, 1]), (U[:, 2] - U[:, 0])
if P_x.size > 0:
ax.scatter(P_x, P_y, facecolors="none", edgecolors=arrow_color, s=point_size * 1.2)
ax.quiver(P_x, P_y, dX * arrow_scale, dY * arrow_scale,
angles="xy", color=arrow_color, alpha=0.8)
ax.set_xlabel("Splice" if x == "splice" else "Unsplice")
ax.set_ylabel("Splice" if y == "splice" else "Unsplice")
ax.set_title(f"GRAVITY gene velocities (expression) – {gene} [{x} vs {y}]")
if axis_off:
ax.axis("off")
ax.grid(False)
if output_path is not None:
out = Path(output_path).expanduser().resolve()
out.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out, dpi=300, bbox_inches="tight")
log_verbose(f"[gravity] saved velocity plot to {out}", level=2)
if show:
plt.show()
return ax