"""
Apply a Cellucid `.cellucid-session` bundle to an `anndata.AnnData`.
Session bundles are treated as untrusted input:
- bounds checks on indices/codes
- dataset mismatch policies
- opt-in destructive mutations
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from typing import Any, Literal
import numpy as np
from .session_bundle import CellucidSessionBundle
from .session_codecs import decode_delta_uvarint, decode_user_defined_codes
logger = logging.getLogger("cellucid.anndata_session")
DatasetMismatchPolicy = Literal["error", "warn_skip", "skip"]
ColumnConflictPolicy = Literal["error", "overwrite", "suffix"]
def _ensure_cellucid_uns(uns: dict[str, Any]) -> dict[str, Any]:
root = uns.setdefault("cellucid", {})
if not isinstance(root, dict):
raise TypeError("adata.uns['cellucid'] must be a dict if present")
return root
def _safe_key(s: str) -> str:
key = re.sub(r"[^0-9a-zA-Z_]+", "_", (s or "").strip())
key = re.sub(r"_+", "_", key).strip("_")
if not key:
return "cellucid"
if key[0].isdigit():
return f"_{key}"
return key
def _resolve_column_name(
existing: set[str],
name: str,
policy: ColumnConflictPolicy,
) -> str:
if name not in existing:
return name
if policy == "error":
raise ValueError(f"Column already exists: {name}")
if policy == "overwrite":
return name
if policy != "suffix":
raise ValueError(f"Unknown column conflict policy: {policy}")
i = 2
while True:
candidate = f"{name}__{i}"
if candidate not in existing:
return candidate
i += 1
@dataclass(frozen=True)
class ApplySummary:
    """Immutable record of what applying a session bundle did."""

    # Names of the columns written to ``adata.obs``.
    added_obs_columns: list[str]
    # Names of the columns written to ``adata.var``.
    added_var_columns: list[str]
    # True when dataset-dependent chunks were skipped due to a mismatch.
    skipped_due_to_mismatch: bool
    # Human-readable descriptions of each detected fingerprint mismatch.
    mismatch_reasons: list[str]
[docs]
def apply_cellucid_session_to_anndata(
    bundle: CellucidSessionBundle | str,
    adata: "Any",
    *,
    inplace: bool = False,
    dataset_mismatch: DatasetMismatchPolicy = "warn_skip",
    expected_dataset_id: str | None = None,
    add_highlights: bool = True,
    highlights_prefix: str = "cellucid_highlight__",
    add_user_defined_fields: bool = True,
    user_defined_prefix: str = "",
    include_deleted_user_defined_fields: bool = False,
    store_uns: bool = True,
    column_conflict: ColumnConflictPolicy = "suffix",
    return_summary: bool = False,
) -> "Any" | tuple["Any", ApplySummary]:
    """
    Apply a `.cellucid-session` bundle onto an AnnData object.

    By default returns the (possibly copied) `adata`.
    If `return_summary=True`, returns `(adata, summary)`.

    Args:
        bundle: A ``CellucidSessionBundle``, or a string that is handed to the
            ``CellucidSessionBundle`` constructor (presumably a path to the
            bundle file -- confirm against ``session_bundle``).
        adata: Target object. It is copied with ``adata.copy()`` first unless
            ``inplace`` is true.
        inplace: Mutate ``adata`` directly instead of working on a copy.
        dataset_mismatch: What to do when the bundle's dataset fingerprint
            disagrees with ``adata``: ``"error"`` raises, ``"warn_skip"`` logs
            a warning and skips highlights and user-defined fields, ``"skip"``
            skips them silently. Only inspected when a mismatch is detected.
        expected_dataset_id: When given, compared against the fingerprint's
            ``datasetId`` (if the fingerprint carries one).
        add_highlights: Write one boolean ``adata.obs`` column per highlight
            group listed in the ``highlights/meta`` chunk.
        highlights_prefix: Prefix for highlight column names.
        add_user_defined_fields: Write user-defined categorical fields from
            the ``core/field-overlays`` chunk to ``adata.obs`` / ``adata.var``.
        user_defined_prefix: Prefix for user-defined field column names.
        include_deleted_user_defined_fields: Also apply fields whose
            ``isDeleted`` flag is ``True``.
        store_uns: Record the manifest, fingerprint, applied options and the
            decoded metadata chunks under ``adata.uns["cellucid"]["session"]``.
        column_conflict: Policy used when a target column name already exists
            (see ``_resolve_column_name``).
        return_summary: Return ``(adata, ApplySummary)`` instead of ``adata``.

    Raises:
        ImportError: If pandas cannot be imported.
        ValueError: On a fingerprint mismatch with ``dataset_mismatch="error"``
            or an unknown mismatch policy, and on column conflicts as raised by
            ``_resolve_column_name``.
        TypeError: If ``adata.uns["cellucid"]`` or its ``"session"`` entry
            exists but is not a dict.
    """
    try:
        import pandas as pd  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("apply_cellucid_session_to_anndata requires pandas") from e

    if isinstance(bundle, str):
        bundle = CellucidSessionBundle(bundle)
    if not inplace:
        adata = adata.copy()

    # ---------------------------------------------------------------------
    # Dataset fingerprint validation
    # ---------------------------------------------------------------------
    mismatch_reasons: list[str] = []
    fp = bundle.dataset_fingerprint or {}
    fp_cells = fp.get("cellCount")
    fp_vars = fp.get("varCount")
    fp_dataset_id = fp.get("datasetId")
    adata_n_obs = getattr(adata, "n_obs", None)
    adata_n_vars = getattr(adata, "n_vars", None)
    if isinstance(fp_cells, int) and fp_cells != adata_n_obs:
        mismatch_reasons.append(f"cellCount {fp_cells} != adata.n_obs {adata_n_obs}")
    if isinstance(fp_vars, int) and fp_vars != adata_n_vars:
        mismatch_reasons.append(f"varCount {fp_vars} != adata.n_vars {adata_n_vars}")
    if expected_dataset_id is not None and fp_dataset_id is not None and fp_dataset_id != expected_dataset_id:
        mismatch_reasons.append(f"datasetId {fp_dataset_id!r} != expected_dataset_id {expected_dataset_id!r}")

    skip_dataset_dependent = False
    if mismatch_reasons:
        if dataset_mismatch == "error":
            raise ValueError("Dataset fingerprint mismatch: " + "; ".join(mismatch_reasons))
        if dataset_mismatch == "warn_skip":
            logger.warning("Dataset fingerprint mismatch; skipping dataset-dependent chunks: %s", mismatch_reasons)
        elif dataset_mismatch != "skip":
            raise ValueError(f"Unknown dataset_mismatch policy: {dataset_mismatch}")
        skip_dataset_dependent = True

    added_obs: list[str] = []
    added_var: list[str] = []

    if store_uns:
        cellucid_uns = _ensure_cellucid_uns(adata.uns)
        session_uns = cellucid_uns.setdefault("session", {})
        if not isinstance(session_uns, dict):
            raise TypeError("adata.uns['cellucid']['session'] must be a dict if present")
        session_uns["manifest"] = bundle.manifest
        session_uns["dataset_fingerprint"] = fp
        session_uns["applied"] = {
            "dataset_mismatch_policy": dataset_mismatch,
            "expected_dataset_id": expected_dataset_id,
            "skip_dataset_dependent": skip_dataset_dependent,
        }

    chunk_ids = set(bundle.list_chunk_ids())

    # ---------------------------------------------------------------------
    # Highlights → adata.obs boolean columns
    # ---------------------------------------------------------------------
    if add_highlights and not skip_dataset_dependent and "highlights/meta" in chunk_ids:
        meta = bundle.decode_chunk("highlights/meta")
        if store_uns:
            # adata.uns is looked up again on purpose rather than reusing an
            # earlier reference to the session dict.
            session = _ensure_cellucid_uns(adata.uns).setdefault("session", {})
            session.setdefault("chunks", {})["highlights/meta"] = meta
        pages = meta.get("pages") if isinstance(meta, dict) else None
        if isinstance(pages, list):
            existing_cols = set(adata.obs.columns)
            highlights_uns = None
            if store_uns:
                session = _ensure_cellucid_uns(adata.uns).setdefault("session", {})
                highlights_uns = session.setdefault("highlights", {})
                if not isinstance(highlights_uns, dict):
                    # Unexpected pre-existing value: skip the bookkeeping
                    # instead of failing the whole apply.
                    highlights_uns = None
            # Adding obs columns does not change the number of observations,
            # so the count is read once for every group.
            n_obs = int(getattr(adata, "n_obs", 0))
            for page in pages:
                if not isinstance(page, dict):
                    continue
                page_id = page.get("id")
                page_name = page.get("name")
                groups = page.get("highlightedGroups") or []
                if not isinstance(groups, list):
                    continue
                for group in groups:
                    if not isinstance(group, dict):
                        continue
                    group_id = group.get("id")
                    if not isinstance(group_id, str) or not group_id:
                        continue
                    membership_chunk_id = f"highlights/cells/{group_id}"
                    if membership_chunk_id in chunk_ids:
                        raw = bundle.decode_chunk(membership_chunk_id)
                        # Untrusted input: the decoder is given the valid
                        # index range so it can bounds-check.
                        indices = decode_delta_uvarint(
                            raw,
                            max_count=n_obs,
                            max_index=n_obs - 1,
                        )
                    else:
                        # A group without a membership chunk highlights nothing.
                        indices = np.empty(0, dtype=np.uint32)
                    base_name = f"{highlights_prefix}{_safe_key(group_id)}"
                    col_name = _resolve_column_name(existing_cols, base_name, column_conflict)
                    existing_cols.add(col_name)
                    mask = np.zeros(n_obs, dtype=bool)
                    if indices.size > 0:
                        mask[indices] = True
                    adata.obs[col_name] = pd.Series(mask, index=adata.obs_names)
                    added_obs.append(col_name)
                    if isinstance(highlights_uns, dict):
                        highlights_uns.setdefault("groups", {})[group_id] = {
                            "obs_column": col_name,
                            "page_id": page_id,
                            "page_name": page_name,
                            "group": group,
                        }

    # ---------------------------------------------------------------------
    # User-defined categorical fields → adata.obs/adata.var
    # ---------------------------------------------------------------------
    if add_user_defined_fields and not skip_dataset_dependent and "core/field-overlays" in chunk_ids:
        overlays = bundle.decode_chunk("core/field-overlays")
        if store_uns:
            session = _ensure_cellucid_uns(adata.uns).setdefault("session", {})
            session.setdefault("chunks", {})["core/field-overlays"] = overlays
        udf = overlays.get("userDefinedFields") if isinstance(overlays, dict) else None
        if isinstance(udf, list):
            existing_obs = set(adata.obs.columns)
            existing_var = set(getattr(adata, "var", pd.DataFrame()).columns)
            for field in udf:
                if not isinstance(field, dict):
                    continue
                if field.get("kind") != "category":
                    continue
                if field.get("isDeleted") is True and not include_deleted_user_defined_fields:
                    continue
                field_id = field.get("id")
                if not isinstance(field_id, str) or not field_id:
                    continue
                # Anything that is not explicitly a var field goes to obs.
                target = "var" if field.get("source") == "var" else "obs"
                codes_chunk_id = f"user-defined/codes/{field_id}"
                if codes_chunk_id not in chunk_ids:
                    continue
                raw = bundle.decode_chunk(codes_chunk_id)
                codes = decode_user_defined_codes(raw)

                if target == "obs":
                    expected_len = int(getattr(adata, "n_obs", 0))
                    if codes.shape[0] != expected_len:
                        logger.warning(
                            "Skipping user-defined field %s: codes length %s != adata.n_obs %s",
                            field_id,
                            codes.shape[0],
                            expected_len,
                        )
                        continue
                else:
                    expected_len = int(getattr(adata, "n_vars", 0))
                    if codes.shape[0] != expected_len:
                        logger.warning(
                            "Skipping user-defined var field %s: codes length %s != adata.n_vars %s",
                            field_id,
                            codes.shape[0],
                            expected_len,
                        )
                        continue

                raw_categories = field.get("categories")
                categories = [str(c) for c in raw_categories] if isinstance(raw_categories, list) else []
                if len(set(categories)) != len(categories):
                    # pd.Categorical.from_codes rejects duplicate categories;
                    # skip the field rather than abort on untrusted input.
                    logger.warning(
                        "Skipping user-defined field %s: category labels are not unique",
                        field_id,
                    )
                    continue

                base_key = str(field.get("key") or field_id)
                col_base = f"{user_defined_prefix}{_safe_key(base_key)}"
                taken = existing_obs if target == "obs" else existing_var
                col_name = _resolve_column_name(taken, col_base, column_conflict)
                taken.add(col_name)

                # Pandas uses -1 for NA in categorical codes; sanitize out-of-range codes.
                codes_i32 = codes.astype(np.int32, copy=False)
                if categories:
                    invalid = (codes_i32 < 0) | (codes_i32 >= len(categories))
                    if invalid.any():
                        # Copy first: astype(copy=False) may have returned
                        # the decoder's own array.
                        codes_i32 = codes_i32.copy()
                        codes_i32[invalid] = -1
                    values = pd.Categorical.from_codes(codes_i32, categories=categories, ordered=False)
                else:
                    # No categories list; store raw codes as integers. Keep
                    # them as a plain array: wrapping a Series (RangeIndex) in
                    # pd.Series(..., index=names) would reindex by label and
                    # produce an all-NaN column.
                    values = codes_i32

                if target == "obs":
                    adata.obs[col_name] = pd.Series(values, index=adata.obs_names)
                    added_obs.append(col_name)
                else:
                    adata.var[col_name] = pd.Series(values, index=adata.var_names)
                    added_var.append(col_name)

    summary = ApplySummary(
        added_obs_columns=added_obs,
        added_var_columns=added_var,
        skipped_due_to_mismatch=skip_dataset_dependent,
        mismatch_reasons=mismatch_reasons,
    )
    if return_summary:
        return adata, summary
    return adata