# Source code for cellucid.vector_fields

"""
Vector field utilities for Cellucid.

Cellucid visualizes "vector fields" (animated arrow/flow overlays) as **per-cell
displacement vectors in embedding space**. This module provides helpers to
derive such vectors from common CellRank outputs (transition matrices) and to
store them in AnnData `obsm` using Cellucid’s naming convention.
"""

from __future__ import annotations

from typing import Optional, Union

import numpy as np
from scipy import sparse


def compute_transition_drift(
    transition_matrix: Union[np.ndarray, sparse.spmatrix],
    embedding: np.ndarray,
    *,
    normalize_rows: bool = True,
) -> np.ndarray:
    """
    Compute a per-cell drift vector field from a transition matrix.

    The drift is defined as:

        drift = E[next_embedding | current_cell] - current_embedding

    Where:

        E[next_embedding] = T @ embedding

    Parameters
    ----------
    transition_matrix:
        A square (n_cells, n_cells) matrix. Sparse matrices are recommended
        for large datasets. Typical source: CellRank kernels' transition
        matrices.
    embedding:
        A (n_cells, dim) array of embedding coordinates
        (e.g. `adata.obsm['X_umap']`).
    normalize_rows:
        If True (default), divides each row's expectation by its row-sum.
        This makes the drift invariant to a matrix that isn't strictly
        row-stochastic. Rows summing to zero (cells with no outgoing
        transitions) have no defined expectation and get a zero drift
        vector. If False, the raw ``T @ embedding - embedding`` is returned.

    Returns
    -------
    np.ndarray
        Drift vectors shaped (n_cells, dim) as float32.

    Raises
    ------
    ValueError
        If `embedding` is not 2D, or if `transition_matrix` is not a 2D
        (n_cells, n_cells) matrix matching the rows of `embedding`.
    """
    X = np.asarray(embedding, dtype=np.float32)
    if X.ndim != 2:
        raise ValueError(f"embedding must be 2D, got shape {X.shape}")
    n_cells = X.shape[0]

    if sparse.issparse(transition_matrix):
        T = transition_matrix.tocsr()
    else:
        T = np.asarray(transition_matrix, dtype=np.float32)
        if T.ndim != 2:
            raise ValueError(f"transition_matrix must be 2D, got shape {T.shape}")

    # Validate before multiplying so callers get one clear message instead of
    # a backend-specific "dimension mismatch" from numpy/scipy.
    if T.shape != (n_cells, n_cells):
        raise ValueError(
            f"transition_matrix has shape {T.shape}, expected ({n_cells}, {n_cells}) "
            f"to match embedding shape {X.shape}"
        )

    # csr_matrix.dot / ndarray.dot both return a dense ndarray for a dense X.
    TX = np.asarray(T.dot(X))

    if not normalize_rows:
        return (TX - X).astype(np.float32, copy=False)

    # T.sum(axis=1) is an (n, 1) np.matrix for sparse inputs; flatten it.
    row_sums = np.asarray(T.sum(axis=1)).reshape(-1)
    empty_rows = row_sums == 0
    # Divide by 1 on empty rows only to avoid division by zero; their result
    # is overwritten below.
    TX = TX / np.where(empty_rows, 1.0, row_sums)[:, None]

    drift = (TX - X).astype(np.float32, copy=False)
    # An empty row has no expected next position. Treat the cell as stationary
    # rather than letting TX == 0 produce drift == -X (an arrow to the origin).
    drift[empty_rows] = 0.0
    return drift
def _resolve_embedding_key( adata: "anndata.AnnData", *, basis: str, dim: int, ) -> str: """ Resolve an AnnData `obsm` key for an embedding basis + dimension. Prefers explicit Cellucid-style keys (e.g. `X_umap_2d`) and falls back to Scanpy-style `X_umap` if it matches the requested dimension. """ key_explicit = f"X_{basis}_{dim}d" if key_explicit in adata.obsm: return key_explicit key_base = f"X_{basis}" if key_base in adata.obsm: arr = adata.obsm[key_base] shape = arr.shape if len(shape) == 2 and shape[1] == dim: return key_base raise ValueError(f"{key_base} is {shape[1]}D, expected {dim}D") raise KeyError(f"Missing embedding in obsm: {key_explicit} (or {key_base})")
def add_transition_drift_to_obsm(
    adata: "anndata.AnnData",
    transition_matrix: Union[np.ndarray, sparse.spmatrix],
    *,
    basis: str = "umap",
    field_prefix: str = "T_fwd",
    dim: Optional[int] = None,
    explicit_dim_suffix: bool = True,
    normalize_rows: bool = True,
    overwrite: bool = False,
) -> str:
    """
    Compute drift vectors from a transition matrix and store them in `adata.obsm`.

    The output key follows Cellucid's vector field naming convention:
    - Explicit: `<field_prefix>_<basis>_<dim>d` (default)
    - Implicit: `<field_prefix>_<basis>`

    Examples
    --------
    - Forward drift in UMAP (2D): `T_fwd_umap_2d`
    - Backward drift in UMAP (2D): `T_bwd_umap_2d`

    Parameters
    ----------
    adata:
        AnnData object to modify.
    transition_matrix:
        (n_cells, n_cells) matrix, dense or sparse.
    basis:
        Embedding basis name (default: "umap").
    field_prefix:
        Prefix for the vector field (default: "T_fwd"). Use e.g. "T_bwd" for
        backward matrices.
    dim:
        Embedding dimensionality to use. If None, it is inferred from the best
        available embedding in `obsm`:
        `X_{basis}_3d` → `X_{basis}_2d` → `X_{basis}_1d` → `X_{basis}`
        (if 1D/2D/3D; a 1D array counts as a single column).
    explicit_dim_suffix:
        If True, appends `_{dim}d` to the output key (recommended for
        clash-safe imports).
    normalize_rows:
        Whether to row-normalize during drift computation
        (see `compute_transition_drift`).
    overwrite:
        If False (default), raises if the target key already exists.

    Returns
    -------
    str
        The `obsm` key written.

    Raises
    ------
    KeyError
        If no usable embedding is found in `obsm`, or the output key already
        exists and `overwrite` is False.
    ValueError
        If an embedding's shape is inconsistent with its key or with `dim`,
        or `transition_matrix` does not match the embedding.
    """
    if dim is None:
        dim, emb_key = _infer_embedding_dim_and_key(adata, basis)
    else:
        emb_key = _resolve_embedding_key(adata, basis=basis, dim=dim)

    base_id = f"{field_prefix}_{basis}"
    out_key = f"{base_id}_{dim}d" if explicit_dim_suffix else base_id
    # Fail fast: check for a clash before the potentially expensive T @ X.
    if not overwrite and out_key in adata.obsm:
        raise KeyError(f"adata.obsm already contains key '{out_key}' (set overwrite=True to replace)")

    X = np.asarray(adata.obsm[emb_key], dtype=np.float32)
    if X.ndim == 1:
        # A 1D Scanpy-style embedding is a single coordinate per cell; give it
        # the (n_cells, 1) shape compute_transition_drift requires.
        X = X.reshape(-1, 1)

    drift = compute_transition_drift(transition_matrix, X, normalize_rows=normalize_rows)
    adata.obsm[out_key] = drift
    return out_key


def _infer_embedding_dim_and_key(adata: "anndata.AnnData", basis: str) -> tuple[int, str]:
    """
    Pick the best embedding for `basis` when no `dim` was given.

    Returns ``(dim, obsm_key)``. Explicit Cellucid-style keys
    (`X_<basis>_<dim>d`) are preferred, highest dimension first. The
    Scanpy-style `X_<basis>` is used only if it is 1D, or 2D with 1-3 columns.
    """
    for candidate_dim in (3, 2, 1):
        explicit_key = f"X_{basis}_{candidate_dim}d"
        if explicit_key not in adata.obsm:
            continue
        shape = adata.obsm[explicit_key].shape
        if len(shape) != 2 or shape[1] != candidate_dim:
            raise ValueError(f"{explicit_key} has shape {shape}, expected (n_cells, {candidate_dim})")
        return candidate_dim, explicit_key

    base_key = f"X_{basis}"
    if base_key not in adata.obsm:
        raise KeyError(
            f"dim not provided and no embedding found: "
            f"{base_key} or X_{basis}_{{1,2,3}}d"
        )
    shape = adata.obsm[base_key].shape
    if len(shape) == 1:
        return 1, base_key
    if len(shape) == 2 and shape[1] in (1, 2, 3):
        return int(shape[1]), base_key
    raise ValueError(f"Cannot infer dim from {base_key} with shape {shape}")