"""BATC metric computation utilities.
This module implements the BATC (Branch-Aware Tangent Consistency) metric. BATC
quantifies how well per-cell velocity vectors align with principal curves that
connect successive cell clusters along a lineage graph.
The implementation follows these high-level steps:
1. Extract a 2D embedding and velocity representation from an :class:`AnnData`
object (from ``.obsm`` or ``.layers``).
2. For each directed edge (``source`` → ``target`` cluster) fit a smooth
parametric curve using a robust PCHIP interpolation over the projected
manifold points belonging to the two clusters.
3. Project every cell along the edge onto the fitted curve and compute the
cosine similarity between its velocity vector and the local curve tangent.
4. Aggregate cosine scores per edge, per cell (taking the best outgoing branch
when multiple exist), and overall.
The public entry point :func:`compute_batc` exposes configuration options and
can optionally write results back into ``adata.obs``/``adata.uns`` for
downstream inspection.
"""
from __future__ import annotations
from typing import Dict, Iterable, Sequence, Tuple
import numpy as np
from scipy.interpolate import PchipInterpolator
from scipy.spatial import cKDTree
try: # Optional progress feedback – silently skip if tqdm is absent.
from tqdm import tqdm
except Exception: # pragma: no cover - dependency not guaranteed
def tqdm(iterable, **kwargs): # type: ignore
return iterable
from ..utils import log_verbose
__all__ = ["compute_batc"]
ArrayLike = np.ndarray
class ParametricCurve2D:
"""Piecewise-cubic parametric curve ``(x(s), y(s))`` defined on ``s ∈ [0, 1]``."""
def __init__(self, parameters: ArrayLike, points: ArrayLike) -> None:
s = np.asarray(parameters, dtype=float)
pts = np.asarray(points, dtype=float)
if s.ndim != 1 or pts.ndim != 2 or pts.shape[1] != 2:
raise ValueError("ParametricCurve2D expects 1D parameters and Nx2 points.")
self.fx = PchipInterpolator(s, pts[:, 0])
self.fy = PchipInterpolator(s, pts[:, 1])
def eval(self, u: ArrayLike) -> ArrayLike:
u_arr = np.clip(np.asarray(u, dtype=float), 0.0, 1.0)
return np.column_stack([self.fx(u_arr), self.fy(u_arr)])
def deriv(self, u: ArrayLike, *, h: float = 1e-3) -> ArrayLike:
"""Return numerical first derivative using centred differences."""
u_arr = np.asarray(u, dtype=float)
u_minus = np.clip(u_arr - h, 0.0, 1.0)
u_plus = np.clip(u_arr + h, 0.0, 1.0)
xy_minus = self.eval(u_minus)
xy_plus = self.eval(u_plus)
denom = (u_plus - u_minus)[:, None]
denom[denom == 0.0] = 1e-12
return (xy_plus - xy_minus) / denom
def _curve_eval(curve: ParametricCurve2D, u: float | ArrayLike, *, derivative: bool) -> ArrayLike:
if derivative:
values = curve.deriv(np.atleast_1d(u))
else:
values = curve.eval(np.atleast_1d(u))
if np.ndim(u) == 0:
return values[0]
return values
def _extract_embedding_and_velocity(
adata,
embedding_key: str,
velocity_key: str,
) -> Tuple[ArrayLike, ArrayLike]:
"""Return ``(X, V)`` matrices from an :class:`AnnData` object.
Lookup order: direct keys, ``X_<key>`` / ``velocity_<key>`` style keys, then
the ``layers`` container.
"""
if hasattr(adata, "obsm"):
if embedding_key in adata.obsm and velocity_key in adata.obsm:
return np.asarray(adata.obsm[embedding_key]), np.asarray(adata.obsm[velocity_key])
if embedding_key.startswith("X_"):
X_key = embedding_key
base = embedding_key[1:]
else:
X_key = f"X_{embedding_key}"
base = embedding_key
V_key = velocity_key
if base and f"{velocity_key}_{base}" in adata.obsm:
V_key = f"{velocity_key}_{base}"
elif base.startswith("_") and f"{velocity_key}{base}" in adata.obsm:
V_key = f"{velocity_key}{base}"
if X_key in adata.obsm and V_key in adata.obsm:
return np.asarray(adata.obsm[X_key]), np.asarray(adata.obsm[V_key])
if hasattr(adata, "layers"):
layers = getattr(adata, "layers", {})
if embedding_key in layers and velocity_key in layers:
return np.asarray(layers[embedding_key]), np.asarray(layers[velocity_key])
if embedding_key.startswith("X_"):
base = embedding_key[1:]
else:
base = embedding_key
alt_X = base
alt_V = f"{velocity_key}_{base}" if base else velocity_key
if alt_X in layers and alt_V in layers:
return np.asarray(layers[alt_X]), np.asarray(layers[alt_V])
raise KeyError(
f"Could not find a matching embedding/velocity pair for keys '{embedding_key}' and '{velocity_key}'."
)
def _fit_curve_pchip(
points: ArrayLike,
src_center: ArrayLike,
tgt_center: ArrayLike,
*,
use_bins: bool,
n_bins: int,
min_per_bin: int,
n_samples: int,
) -> Tuple[ParametricCurve2D, ArrayLike, ArrayLike, float]:
"""Fit a parametric curve through points belonging to source & target clusters."""
P = np.asarray(points, dtype=float)
direction = tgt_center - src_center
direction_norm = np.linalg.norm(direction)
unit_dir = direction / (direction_norm + 1e-12)
projection = (P - src_center) @ unit_dir
order = np.argsort(projection)
P_sorted = P[order]
proj_sorted = projection[order]
if use_bins and len(P_sorted) >= n_bins:
bin_edges = np.linspace(proj_sorted.min(), proj_sorted.max(), n_bins + 1)
anchors = []
for idx in range(n_bins):
if idx < n_bins - 1:
mask = (proj_sorted >= bin_edges[idx]) & (proj_sorted < bin_edges[idx + 1])
else:
mask = (proj_sorted >= bin_edges[idx]) & (proj_sorted <= bin_edges[idx + 1])
if np.count_nonzero(mask) >= min_per_bin:
anchors.append(P_sorted[mask].mean(axis=0))
skeleton = np.vstack(anchors) if len(anchors) >= 3 else P_sorted
else:
skeleton = P_sorted
if len(skeleton) < 2:
skeleton = np.vstack([P_sorted[0], P_sorted[-1]])
segment_lengths = np.linalg.norm(np.diff(skeleton, axis=0), axis=1)
s = np.r_[0.0, np.cumsum(segment_lengths)]
if s[-1] == 0.0:
s[-1] = 1.0
s /= s[-1]
curve = ParametricCurve2D(s, skeleton)
u_grid = np.linspace(0.0, 1.0, int(n_samples))
samples = curve.eval(u_grid)
tree = cKDTree(samples)
_, nearest = tree.query(P, k=1)
u_all = u_grid[nearest]
start = curve.eval(0.0).ravel()
end = curve.eval(1.0).ravel()
tangent = end - start
orient = 1.0
if direction_norm > 0 and np.dot(tangent, direction) < 0:
orient = -1.0
return curve, u_all[order], order, orient
def _prepare_cluster_data(
X: ArrayLike,
labels: ArrayLike,
) -> Tuple[Dict[str, ArrayLike], Dict[str, ArrayLike]]:
unique_labels = np.unique(labels)
idx_by_cluster = {lbl: np.where(labels == lbl)[0] for lbl in unique_labels}
centroids = {lbl: X[idx].mean(axis=0) if len(idx) else np.zeros(2) for lbl, idx in idx_by_cluster.items()}
return idx_by_cluster, centroids
[docs]
def compute_batc(
adata,
cluster_edges: Sequence[Tuple[str, str]],
*,
cluster_key: str = "clusters",
embedding_key: str = "umap",
velocity_key: str = "velocity",
use_bins: bool = True,
n_bins: int = 60,
min_per_bin: int = 5,
n_samples: int = 800,
store_in_adata: bool = True,
progress: bool = False,
) -> float:
"""Compute the Branching-aware Trajectory Consistency (BATC) score.
BATC evaluates RNA velocity fields on predefined lineage graphs by fitting
smooth principal curves between successive clusters and comparing their
tangents with per-cell velocity vectors. For each directed edge
:math:`A \rightarrow B`, a curve :math:`\gamma_{A\to B}(u)` is fitted on the
2D embedding of cells belonging to :math:`A \cup B`. Every cell :math:`c`
on the edge is projected onto the curve and the cosine similarity between the
local tangent :math:`\tau_c` and its velocity :math:`v_c` is recorded:
. math::
\mathrm{BATC}_{A\to B} = \frac{1}{|I_{A\cup B}|}
\sum_{c \in I_{A\cup B}}
\frac{v_c \cdot \tau_c}{\lVert v_c \rVert\,\lVert \tau_c \rVert}.
To handle branching, for each source cluster :math:`A` with outgoing targets
:math:`B_1, \ldots, B_m` we compute these cosine scores for every outgoing
edge and, for each cell :math:`c \in I_A`, keep the best matching branch
. math::
b_c^{(A)} = \max_{B \in \mathrm{Out}(A)}
\frac{v_c \cdot \tau_c^{(A\to B)}}{\lVert v_c \rVert\,\lVert \tau_c^{(A\to B)} \rVert}.
The final dataset-level BATC is the cell-weighted mean of these branch-aware
scores over all source clusters:
. math::
\mathrm{BATC}_{\mathrm{overall}} =
\frac{1}{\sum_A |I_A|} \sum_A \sum_{c \in I_A} b_c^{(A)}.
Zero-norm vectors are mapped to ``nan`` scores, and all averages use
``numpy.nanmean`` for robustness.
Parameters
----------
adata:
Annotated data matrix containing the embedding and velocity arrays.
cluster_edges:
Directed edges describing permitted transitions between clusters. Nodes
are cast to strings to match ``adata.obs[cluster_key]``.
cluster_key:
Column in ``adata.obs`` holding cluster labels. Defaults to ``'clusters'``.
embedding_key:
Key identifying the 2D embedding. Defaults to ``'umap'`` which is
resolved to ``adata.obsm['X_umap']`` when present. Any other key is
resolved in the same fashion (``X_<key>`` or direct entry).
velocity_key:
Key identifying the velocity representation (must align with the
embedding). With the default ``'velocity'`` the function looks for
``adata.obsm['velocity_umap']`` when ``embedding_key='umap'``.
use_bins, n_bins, min_per_bin:
Controls for the skeletonisation step when fitting the principal curves.
n_samples:
Number of evaluation points used to project cells onto each curve.
store_in_adata:
If ``True`` (default) write per-cell and aggregate scores back to the
``adata.obs``/``adata.uns`` containers.
progress:
Whether to display a progress bar over edges (requires ``tqdm``).
Returns
-------
float
The BATC overall score (best-per-cell cosine mean).
"""
cluster_series = adata.obs[cluster_key].astype(str)
X, V = _extract_embedding_and_velocity(adata, embedding_key, velocity_key)
if X.shape != V.shape:
raise ValueError(f"Embedding shape {X.shape} does not match velocity shape {V.shape}.")
if X.shape[1] != 2:
raise ValueError("BATC expects a 2D embedding.")
clusters = cluster_series.to_numpy()
idx_by_cluster, centroids_data = _prepare_cluster_data(X, clusters)
edges = [(str(src), str(tgt)) for src, tgt in cluster_edges]
edge_fits: Dict[Tuple[str, str], Dict[str, object]] = {}
iterable = edges
if progress:
iterable = tqdm(iterable, desc="Fitting BATC curves")
for src, tgt in iterable:
src_idx = idx_by_cluster.get(src, np.empty(0, dtype=int))
tgt_idx = idx_by_cluster.get(tgt, np.empty(0, dtype=int))
if src_idx.size == 0 and tgt_idx.size == 0:
log_verbose(f"[batc] skip edge {src}->{tgt}: no cells in either cluster", level=2)
continue
points_idx = np.concatenate([src_idx, tgt_idx])
curve, u_sorted, order, orient = _fit_curve_pchip(
X[points_idx],
centroids_data.get(src, X[points_idx].mean(axis=0)),
centroids_data.get(tgt, X[points_idx].mean(axis=0)),
use_bins=use_bins,
n_bins=n_bins,
min_per_bin=min_per_bin,
n_samples=n_samples,
)
edge_fits[(src, tgt)] = {
"curve": curve,
"u": u_sorted,
"sorted_idx": points_idx[order],
"orient": orient,
}
edge_cosines: Dict[Tuple[str, str], Dict[int, float]] = {}
for (src, tgt), fit in edge_fits.items():
curve: ParametricCurve2D = fit["curve"] # type: ignore[assignment]
u_vals = fit["u"]
sorted_idx = fit["sorted_idx"]
orient = float(fit["orient"])
cos_scores: Dict[int, float] = {}
for pos, cell_idx in enumerate(sorted_idx):
velocity = V[cell_idx]
tangent = orient * _curve_eval(curve, u_vals[pos], derivative=True)
if np.linalg.norm(tangent) == 0 or np.linalg.norm(velocity) == 0:
cos_scores[int(cell_idx)] = np.nan
continue
cos = float(np.dot(tangent, velocity) / (np.linalg.norm(tangent) * np.linalg.norm(velocity)))
cos_scores[int(cell_idx)] = cos
edge_cosines[(src, tgt)] = cos_scores
edge_mean_scores: Dict[Tuple[str, str], float] = {}
all_values: list[ArrayLike] = []
for edge, mapping in edge_cosines.items():
values = np.array(list(mapping.values()), dtype=float)
if values.size == 0:
edge_mean_scores[edge] = np.nan
continue
edge_mean_scores[edge] = float(np.nanmean(values))
all_values.append(values)
overall_cell_weighted_mean = float(np.nanmean(np.concatenate(all_values))) if all_values else np.nan
overall_by_edge = (
float(np.nanmean(list(edge_mean_scores.values()))) if edge_mean_scores else np.nan
)
next_map: Dict[str, list[str]] = {}
for src, tgt in edges:
next_map.setdefault(src, []).append(tgt)
best_cosine_per_cell = np.full(X.shape[0], np.nan, dtype=float)
best_target_per_cell = np.full(X.shape[0], "", dtype=object)
for src, targets in next_map.items():
src_indices = idx_by_cluster.get(src, np.empty(0, dtype=int))
if src_indices.size == 0:
continue
stacks = []
for tgt in targets:
mapping = edge_cosines.get((src, tgt), {})
scores = np.array([mapping.get(int(cell_idx), np.nan) for cell_idx in src_indices], dtype=float)
stacks.append(scores)
if not stacks:
continue
matrix = np.vstack(stacks)
best_idx = np.nanargmax(matrix, axis=0)
row_indices = np.arange(matrix.shape[1])
best_values = matrix[best_idx, row_indices]
best_cosine_per_cell[src_indices] = best_values
targets_array = np.asarray(targets, dtype=object)
best_target_per_cell[src_indices] = targets_array[best_idx]
valid_mask = np.isfinite(best_cosine_per_cell)
batc_overall = float(np.nanmean(best_cosine_per_cell[valid_mask])) if valid_mask.any() else np.nan
if store_in_adata:
adata.uns["batc_overall"] = batc_overall
return batc_overall