#!/usr/bin/env python3
"""
Export raw dataframes/arrays to files used by the WebGL viewer.
Includes memory/disk optimization features:
- Quantization for continuous data (var/gene expression and obs continuous)
- Auto dtype selection for categorical obs based on category count
- Gzip compression for all binary files
Instead of AnnData, accepts:
- latent_space: (n_cells, n_dims) numpy/sparse array for outlier quantile calculation
- X_umap_1d / X_umap_2d / X_umap_3d: explicit embeddings (at least one required)
- vector_fields: dict[str, array] of per-cell displacement vectors (optional)
- obs: pandas DataFrame with cell metadata columns
- var: pandas DataFrame with gene/feature metadata
- gene_expression: (n_cells, n_genes) numpy/sparse array for gene expression matrix
- var_gene_id_column: column name in var containing gene identifiers, or "index" to use var.index
- connectivities: sparse matrix with KNN connectivities
"""
import gzip
import json
import re
import tqdm
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Sequence, Union, Literal
import numpy as np
import pandas as pd
from scipy import sparse
# Default output directory: an "exports" folder under the working directory.
# NOTE: Path.cwd() is evaluated once, when this module is imported, so a later
# os.chdir() does not move this default.
DEFAULT_EXPORT_DIR = Path.cwd() / "exports"
# Default sub-directory names; they are the defaults of prepare()'s
# obs_binary_dirname / var_binary_dirname / connectivity_binary_dirname parameters.
DEFAULT_OBS_DIRNAME = "obs"
DEFAULT_VAR_DIRNAME = "var"
DEFAULT_CONNECTIVITY_DIRNAME = "connectivity"
# Manifest format version for compact format
MANIFEST_FORMAT_VERSION = "compact_v1"
def _safe_filename_component(name: str) -> str:
    """
    Turn an arbitrary field key into a string that is safe inside a filename.

    Every run of characters outside ``A-Z a-z 0-9 . _ -`` collapses into a
    single underscore; leading and trailing dots/underscores are then trimmed.

    Parameters
    ----------
    name : str
        Field key. It is passed through ``str()`` first, so non-string keys
        are accepted as well.

    Returns
    -------
    str
        The sanitized key, or ``"field"`` when nothing is left after trimming.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", str(name))
    trimmed = collapsed.strip("._")
    if trimmed:
        return trimmed
    return "field"
def _to_dense(arr: Union[np.ndarray, sparse.spmatrix]) -> np.ndarray:
    """
    Return ``arr`` as a dense numpy array.

    Sparse inputs (as reported by ``scipy.sparse.issparse``) are densified via
    ``toarray()``; anything else is handed to ``np.asarray`` unchanged.
    """
    if not sparse.issparse(arr):
        return np.asarray(arr)
    densified = arr.toarray()
    return np.asarray(densified)
def _file_exists_skip(path: Path, description: str, force: bool = False) -> bool:
    """
    Decide whether writing ``path`` should be skipped.

    Parameters
    ----------
    path : Path
        Target file to check.
    description : str
        Human-readable name used in the skip message.
    force : bool
        When True, existing files are never skipped.

    Returns
    -------
    bool
        True if the file already exists and ``force`` is False (a notice is
        printed in that case); False otherwise.
    """
    # Nothing on disk, or the caller asked to overwrite: go ahead and write.
    if not path.exists() or force:
        return False
    print(f"⚠ Skipping {description}: {path} already exists (use force=True to overwrite)")
    return True
def _write_binary(
    path: Path,
    data: np.ndarray,
    compression: Optional[int] = None,
) -> Path:
    """
    Dump an array's raw bytes to disk, gzip-compressed on request.

    Parameters
    ----------
    path : Path
        Destination path. With compression enabled the file is written to
        ``path`` with '.gz' appended instead.
    data : np.ndarray
        Array whose raw bytes are written.
    compression : int or None
        Gzip compression level. None, 0 or a negative value disables
        compression.

    Returns
    -------
    Path
        The path that was actually written (carries the .gz suffix when
        compressed).
    """
    # Uncompressed case first: plain raw dump to the requested path.
    if not compression or compression <= 0:
        data.tofile(path)
        return path

    compressed_path = Path(str(path) + ".gz")
    with gzip.open(compressed_path, 'wb', compresslevel=compression) as handle:
        handle.write(data.tobytes())
    return compressed_path
def _quantize_continuous(
    values: np.ndarray,
    bits: int = 8,
    field_name: str = "unknown",
) -> tuple[np.ndarray, float, float, float, dict]:
    """
    Quantize continuous values to uint8 or uint16 with a reserved missing marker.

    Finite values are mapped linearly from ``[min_val, max_val]`` onto
    ``[0, max_quant]`` and rounded to the nearest level; NaN and +/-Inf are
    stored as the highest code of the dtype (255 or 65535).

    Parameters
    ----------
    values : np.ndarray
        One-dimensional numeric values to quantize.
    bits : int
        Number of bits for quantization. 8 selects uint8; any other value is
        treated as 16 (uint16).
    field_name : str
        Name of field for debug messages. Currently not used inside this
        function; kept so existing callers keep working.

    Returns
    -------
    quantized : np.ndarray
        Quantized values as uint8 or uint16.
    min_val : float
        Minimum finite value (for dequantization); 0.0 if nothing is finite.
    max_val : float
        Maximum finite value (for dequantization); 1.0 if nothing is finite,
        ``min_val + 1.0`` if all finite values are identical.
    scale : float
        ``max_quant / (max_val - min_val)`` (for dequantization).
    stats : dict
        Counts ``n_total``, ``n_valid``, ``n_nan``, ``n_inf`` plus, when
        applicable, ``data_min``/``data_max``, ``constant_value`` and
        ``warning`` ("all_invalid").
    """
    values = np.asarray(values)
    n_total = len(values)

    # Identify problematic values
    nan_mask = np.isnan(values)
    inf_mask = np.isinf(values)
    invalid_mask = nan_mask | inf_mask
    valid_mask = ~invalid_mask
    n_valid = int(valid_mask.sum())

    stats = {
        "n_total": n_total,
        "n_valid": n_valid,
        "n_nan": int(nan_mask.sum()),
        "n_inf": int(inf_mask.sum()),
    }

    # Work in float64: float32 arithmetic can overflow in `value - min` for
    # very wide ranges and in `scale` for very narrow ones.
    valid_values = values[valid_mask].astype(np.float64)

    if n_valid == 0:
        # All invalid values
        min_val, max_val = 0.0, 1.0
        stats["warning"] = "all_invalid"
    else:
        min_val = float(np.min(valid_values))
        max_val = float(np.max(valid_values))
        stats["data_min"] = min_val
        stats["data_max"] = max_val
        # Avoid division by zero
        if max_val == min_val:
            max_val = min_val + 1.0
            stats["constant_value"] = True

    if bits == 8:
        max_quant = 254  # Reserve 255 for NaN/Inf
        dtype = np.uint8
        nan_value = 255
    else:  # 16 bits
        max_quant = 65534  # Reserve 65535 for NaN/Inf
        dtype = np.uint16
        nan_value = 65535

    scale = max_quant / (max_val - min_val)

    quantized = np.empty(n_total, dtype=dtype)
    # Only quantize valid values to avoid numpy warnings
    if n_valid > 0:
        normalized = (valid_values - min_val) * scale
        # Round to the nearest level before the integer cast; a bare cast
        # truncates, which biases values downward and can keep the maximum
        # from reaching max_quant.
        levels = np.clip(np.rint(normalized), 0, max_quant)
        quantized[valid_mask] = levels.astype(dtype)
    # Set invalid values to reserved marker
    quantized[invalid_mask] = nan_value

    return quantized, min_val, max_val, scale, stats
def _select_category_dtype(n_categories: int) -> tuple[np.dtype, int]:
    """
    Pick the smallest unsigned dtype able to store the category codes.

    The highest value of each dtype is reserved as the missing/NaN marker,
    so uint8 covers codes 0-254 and uint16 covers codes 0-65534. No upper
    bound is checked here: counts above 254 always yield uint16.

    Parameters
    ----------
    n_categories : int
        Number of unique categories (not counting missing).

    Returns
    -------
    dtype : np.dtype
        ``np.uint8`` or ``np.uint16``.
    missing_value : int
        Code to use for missing/NaN entries (255 or 65535).
    """
    fits_in_byte = n_categories <= 254
    return (np.uint8, 255) if fits_in_byte else (np.uint16, 65535)
def _report_quantization_stats(field_name: str, stats: dict, field_type: str = "field") -> None:
    """
    Print warnings derived from a quantization statistics dict.

    Parameters
    ----------
    field_name : str
        Name of the field that was quantized.
    stats : dict
        Statistics from _quantize_continuous. Keys read here: ``n_nan``,
        ``n_inf``, ``n_total``, ``warning`` and ``constant_value``; missing
        keys are treated as zero / absent.
    field_type : str
        Type of field for messages (e.g., "obs field", "gene", "outlier quantiles").
    """
    total = stats.get("n_total", 0)

    def _percent(count):
        # Share of all values; falls back to 0 when the total is unknown.
        return 100 * count / total if total > 0 else 0

    problems = []
    nan_count = stats.get("n_nan", 0)
    if nan_count > 0:
        problems.append(f"{nan_count:,} NaN values ({_percent(nan_count):.1f}%)")
    inf_count = stats.get("n_inf", 0)
    if inf_count > 0:
        problems.append(f"{inf_count:,} Inf values ({_percent(inf_count):.1f}%)")

    # The "nothing valid" warning replaces the per-issue summary.
    nothing_valid = stats.get("warning") == "all_invalid"
    if nothing_valid:
        print(f" ⚠ WARNING: {field_type} '{field_name}' has NO valid values (all NaN/Inf)")
    elif problems:
        print(f" ⚠ {field_type} '{field_name}': {', '.join(problems)} → mapped to missing marker")

    if stats.get("constant_value"):
        print(f" ℹ {field_type} '{field_name}' has constant value (min=max)")
def _check_data_quality(values: np.ndarray, field_name: str, field_type: str = "field") -> None:
    """
    Print a warning when ``values`` is empty or holds NaN / infinite entries.

    Parameters
    ----------
    values : np.ndarray
        Values to check.
    field_name : str
        Name of the field, used in the message.
    field_type : str
        Type of field for messages.
    """
    total = len(values)
    if total == 0:
        print(f" ⚠ WARNING: {field_type} '{field_name}' is empty")
        return

    nan_count = int(np.isnan(values).sum())
    neg_inf_count = int(np.isneginf(values).sum())
    pos_inf_count = int(np.isposinf(values).sum())

    findings = []
    if nan_count > 0:
        # Only NaN is reported with its share of all values.
        nan_pct = 100 * nan_count / total
        findings.append(f"{nan_count:,} NaN ({nan_pct:.1f}%)")
    if neg_inf_count > 0:
        findings.append(f"{neg_inf_count:,} -Inf")
    if pos_inf_count > 0:
        findings.append(f"{pos_inf_count:,} +Inf")

    if not findings:
        return
    print(f" ⚠ {field_type} '{field_name}' contains: {', '.join(findings)}")
def _compute_centroids_for_field(
    coords: np.ndarray,
    codes: np.ndarray,
    categories: list[str],
    outlier_quantile: float = 0.95,
    min_points: int = 10,
) -> list[dict]:
    """
    Compute one outlier-trimmed centroid per category in embedding space.

    For every category with at least ``min_points`` members the mean position
    is taken. When the category has more than ``min_points`` members, points
    whose distance to that mean exceeds the ``outlier_quantile`` quantile are
    dropped and the mean is recomputed, provided at least ``min_points``
    inliers remain. Positions have as many components as ``coords`` has
    columns.

    Parameters
    ----------
    coords : np.ndarray
        Per-cell coordinates, indexed as ``coords[rows, :]``.
    codes : np.ndarray
        Per-cell category codes; category ``i`` is matched with ``codes == i``.
    categories : list[str]
        Category labels, in code order.
    outlier_quantile : float
        Distance quantile kept as inliers. Values outside the open interval
        (0.5, 1.0) are replaced by 0.95.
    min_points : int
        Minimum category size for a centroid to be emitted.

    Returns
    -------
    list[dict]
        One ``{"category", "position", "n_points"}`` dict per qualifying
        category, where ``n_points`` is the number of points in the final mean.

    Raises
    ------
    ValueError
        If ``coords`` and ``codes`` differ in their first dimension.
    """
    if coords.shape[0] != codes.shape[0]:
        raise ValueError("coords and codes must have the same length.")

    if not (0.5 < outlier_quantile < 1.0):
        outlier_quantile = 0.95

    results: list[dict] = []
    for code, label in enumerate(categories):
        members = np.nonzero(codes == code)[0]
        member_count = members.size
        if member_count < min_points:
            continue

        cloud = coords[members, :]  # (n, ndim)
        centre = cloud.mean(axis=0)
        points_used = member_count

        # Trimming only happens when there is headroom above min_points.
        if member_count > min_points:
            radii = np.linalg.norm(cloud - centre, axis=1)
            cutoff = float(np.quantile(radii, outlier_quantile))
            keep = radii <= cutoff
            kept_count = int(keep.sum())
            if kept_count >= min_points:
                centre = cloud[keep, :].mean(axis=0)
                points_used = kept_count

        results.append(
            {
                "category": str(label),
                "position": centre.astype(float).tolist(),
                "n_points": int(points_used),
            }
        )
    return results
def _compute_centroids_for_all_dimensions(
    embeddings: dict[int, np.ndarray],
    codes: np.ndarray,
    categories: list[str],
    outlier_quantile: float = 0.95,
    min_points: int = 10,
) -> dict[int, list[dict]]:
    """
    Run the per-field centroid computation once for every supplied embedding.

    Parameters
    ----------
    embeddings : dict[int, np.ndarray]
        Coordinates keyed by dimensionality.
    codes, categories, outlier_quantile, min_points
        Forwarded unchanged to ``_compute_centroids_for_field``.

    Returns
    -------
    dict[int, list[dict]]
        Centroid lists keyed by the same dimension keys as ``embeddings``.
    """
    return {
        dim: _compute_centroids_for_field(
            dim_coords, codes, categories, outlier_quantile, min_points
        )
        for dim, dim_coords in embeddings.items()
    }
def _compute_latent_space_quantiles(
    latent: np.ndarray,
    codes: np.ndarray,
    categories: list[str],
    min_points: int = 10,
) -> np.ndarray:
    """
    Rank every cell by its latent-space distance to its own category centroid.

    For each category with at least ``min_points`` members, the distance of
    every member to the category mean is computed and converted into the
    fraction of members whose distance is less than or equal to it.

    Parameters
    ----------
    latent : np.ndarray
        Per-cell latent coordinates, indexed as ``latent[rows, :]``.
    codes : np.ndarray
        Per-cell category codes; category ``i`` is matched with ``codes == i``.
    categories : list[str]
        Category labels; only their positions (codes) are used here.
    min_points : int
        Categories with fewer members are left untouched.

    Returns
    -------
    np.ndarray
        float32 array with one value per row of ``latent``; cells that belong
        to no processed category stay NaN.
    """
    result = np.full(latent.shape[0], np.nan, dtype=np.float32)

    for code, _label in enumerate(categories):
        members = np.nonzero(codes == code)[0]
        member_count = members.size
        if member_count < min_points:
            continue

        cloud = latent[members, :]
        radii = np.linalg.norm(cloud - cloud.mean(axis=0), axis=1)
        # side='right' counts ties as well, so the farthest cell gets rank n.
        ranks = np.searchsorted(np.sort(radii), radii, side='right')
        result[members] = ranks.astype(np.float32) / member_count

    return result
# Main export entry point.
def prepare(
latent_space: Optional[Union[np.ndarray, sparse.spmatrix]] = None,
obs: Optional[pd.DataFrame] = None,
var: Optional[pd.DataFrame] = None,
gene_expression: Optional[Union[np.ndarray, sparse.spmatrix]] = None,
var_gene_id_column: str = "index",
gene_identifiers: Optional[Sequence[str]] = None,
connectivities: Optional[sparse.spmatrix] = None,
out_dir: Path | str = DEFAULT_EXPORT_DIR,
obs_keys: Optional[Sequence[str]] = None,
centroid_outlier_quantile: float = 0.95,
centroid_min_points: int = 10,
obs_manifest_filename: str = "obs_manifest.json",
obs_binary_dirname: str = DEFAULT_OBS_DIRNAME,
var_manifest_filename: str = "var_manifest.json",
var_binary_dirname: str = DEFAULT_VAR_DIRNAME,
connectivity_manifest_filename: str = "connectivity_manifest.json",
connectivity_binary_dirname: str = DEFAULT_CONNECTIVITY_DIRNAME,
force: bool = False,
# Optimization parameters (all disabled by default)
var_quantization: Optional[int] = None,
obs_continuous_quantization: Optional[int] = None,
obs_categorical_dtype: Literal["auto", "uint8", "uint16"] = "auto",
compression: Optional[int] = None,
# Dataset metadata parameters (for dataset_identity.json)
dataset_name: Optional[str] = None,
dataset_description: Optional[str] = None,
dataset_id: Optional[str] = None,
source_name: Optional[str] = None,
source_url: Optional[str] = None,
source_citation: Optional[str] = None,
# Multi-dimensional embedding parameters (at least one required)
X_umap_1d: Optional[np.ndarray] = None,
X_umap_2d: Optional[np.ndarray] = None,
X_umap_3d: Optional[np.ndarray] = None,
X_umap_4d: Optional[np.ndarray] = None,
# Optional per-cell vector fields aligned to the embedding(s)
# (e.g. scVelo velocity, CellRank drift vectors)
vector_fields: Optional[dict[str, Union[np.ndarray, sparse.spmatrix]]] = None,
) -> None:
"""
Export raw data arrays to files used by the WebGL viewer.
Memory/Disk Optimization Options
--------------------------------
var_quantization : int or None
Bits for gene expression quantization (8, 16, or None for full float32).
8-bit reduces file size by 4x with minimal visual impact for colormapping.
obs_continuous_quantization : int or None
Bits for continuous obs field quantization (8, 16, or None for full float32).
obs_categorical_dtype : 'auto', 'uint8', or 'uint16'
- 'auto': Select based on number of categories (uint8 if ≤254, else uint16)
- 'uint8': Force uint8 (max 254 categories)
- 'uint16': Force uint16 (max 65534 categories)
compression : int or None
Gzip compression level (1-9). None or 0 disables compression.
Level 6 is a good balance of speed and size. Files get .gz extension.
Multi-Dimensional Embeddings
----------------------------
At least one dimensional embedding must be provided. The viewer supports
switching between different dimensionalities of the same data at runtime.
All embeddings must have the same number of cells (rows) but different
column counts matching their dimensionality.
IMPORTANT: Each embedding is normalized independently to fit within the
[-1, 1] coordinate range. Within each dimension, the same scale factor is
used for all axes to preserve aspect ratios. This ensures each dimension
fills the viewing area optimally without requiring manual zoom adjustment.
X_umap_1d : np.ndarray, optional
1D embedding coordinates, shape (n_cells, 1). Stored as points_1d.bin.
X_umap_2d : np.ndarray, optional
2D embedding coordinates, shape (n_cells, 2). Stored as points_2d.bin.
X_umap_3d : np.ndarray, optional
3D embedding coordinates, shape (n_cells, 3). Stored as points_3d.bin.
This is the primary visualization and is used for centroid computation.
X_umap_4d : np.ndarray, optional
4D embedding coordinates, shape (n_cells, 4). Stored as points_4d.bin.
NOTE: 4D visualization is not yet implemented in the viewer.
vector_fields : dict[str, np.ndarray] or None
Optional per-cell displacement vectors aligned to the embedding space.
Keys follow the same naming convention as AnnData ``obsm``:
- Explicit: ``<field>_umap_<dim>d`` (e.g. ``velocity_umap_2d``, ``T_fwd_umap_3d``)
- Implicit: ``<field>_umap`` with shape ``(n_cells, 1|2|3)``
(used only if the explicit key for that dim is not provided)
Each value must be shaped ``(n_cells, dim)`` (or ``(n_cells,)`` for 1D).
Vectors are scaled by the same per-dimension normalization scale as points.
Standard Parameters
-------------------
latent_space : np.ndarray or sparse matrix
Latent space for outlier quantile calculation, shape (n_cells, n_dims).
obs : pd.DataFrame
Cell metadata, shape (n_cells, n_obs_columns).
var : pd.DataFrame, optional
Gene/feature metadata. Required if gene_expression is provided.
gene_expression : np.ndarray or sparse matrix, optional
Gene expression matrix, shape (n_cells, n_genes).
var_gene_id_column : str
Column name in var containing gene identifiers, or "index" to use var.index.
gene_identifiers : sequence of str, optional
Which genes to export. If None, all genes are exported.
connectivities : sparse matrix, optional
KNN connectivity matrix from scanpy (n_cells, n_cells).
out_dir : Path or str
Output directory (default: exports/ under the current working directory).
obs_keys : sequence of str or None
Which obs columns to export. If None, all columns are exported.
centroid_outlier_quantile : float
Quantile of distances to keep as inliers when computing centroids.
centroid_min_points : int
Minimum number of points in a category to compute a centroid.
force : bool
If True, overwrite existing files. If False, skip files that already exist.
Dataset Metadata Parameters
---------------------------
dataset_name : str, optional
Human-readable name for the dataset (e.g., "Human Lung Cell Atlas").
If not provided, defaults to the output directory name.
dataset_description : str, optional
Description of the dataset.
dataset_id : str, optional
Unique identifier for the dataset. If not provided, a filesystem-safe
version of the dataset_name is used.
source_name : str, optional
Name of the data source (e.g., "HLCA Consortium").
source_url : str, optional
URL to the data source.
source_citation : str, optional
Citation text for the data source.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
obs_binary_dir = out_dir / obs_binary_dirname
obs_binary_dir.mkdir(parents=True, exist_ok=True)
# Normalize compression parameter
if compression is not None and compression <= 0:
compression = None
# =========================================================================
# MULTI-DIMENSIONAL EMBEDDING VALIDATION & PROCESSING
# =========================================================================
# Collect all provided embeddings
embeddings: dict[int, np.ndarray] = {}
if X_umap_1d is not None:
embeddings[1] = np.asarray(X_umap_1d, dtype=np.float32)
if X_umap_2d is not None:
embeddings[2] = np.asarray(X_umap_2d, dtype=np.float32)
if X_umap_3d is not None:
embeddings[3] = np.asarray(X_umap_3d, dtype=np.float32)
if X_umap_4d is not None:
# 4D is a hook for future development - raise error for now
raise NotImplementedError(
"4D visualization is not yet implemented. "
"The X_umap_4d parameter is reserved for future development. "
"Please use X_umap_1d, X_umap_2d, or X_umap_3d for now."
)
if not embeddings:
raise ValueError(
"At least one dimensional embedding must be provided. "
"Use X_umap_1d, X_umap_2d, or X_umap_3d."
)
# Validate each embedding has correct dimensions
n_cells = None
for dim, arr in embeddings.items():
if arr.ndim != 2:
raise ValueError(
f"X_umap_{dim}d must be a 2D array, got shape {arr.shape}."
)
if arr.shape[1] != dim:
raise ValueError(
f"X_umap_{dim}d must have exactly {dim} columns, got {arr.shape[1]}. "
f"Shape is {arr.shape}."
)
if n_cells is None:
n_cells = arr.shape[0]
elif arr.shape[0] != n_cells:
raise ValueError(
f"All embeddings must have the same number of cells. "
f"First embedding has {n_cells} cells, but X_umap_{dim}d has {arr.shape[0]} cells."
)
# =========================================================================
# NORMALIZE EACH EMBEDDING INDEPENDENTLY TO FIT WITHIN [-1, 1] RANGE
# =========================================================================
# Each dimensional embedding (1D, 2D, 3D) is normalized independently so that
# it fills the viewing area optimally. Within each dimension, we use the same
# scale factor for all axes to preserve aspect ratios.
#
# This ensures that switching between dimensions doesn't require manual zoom
# adjustments - each dimension will fill the view appropriately.
normalization_info = {}
for dim, arr in embeddings.items():
# Find min/max for each axis
axis_mins = arr.min(axis=0)
axis_maxs = arr.max(axis=0)
axis_ranges = axis_maxs - axis_mins
# Use the maximum range across all axes to preserve aspect ratio
max_range = float(axis_ranges.max())
if max_range < 1e-8:
max_range = 1.0 # Avoid division by zero for degenerate data
# Center of the bounding box
center = (axis_mins + axis_maxs) / 2
# Scale to fit in [-1, 1] based on the max range (preserves aspect ratio)
scale_factor = 2.0 / max_range
embeddings[dim] = ((arr - center) * scale_factor).astype(np.float32)
# Store info for logging
normalization_info[dim] = {
'original_range': max_range,
'center': center.tolist(),
'scale_factor': scale_factor,
}
# Determine primary 3D coords for centroid computation
# Priority: 3D > 2D (padded) > 1D (padded)
# Note: 4D support is reserved for future development
if 3 in embeddings:
coords3d = embeddings[3]
elif 2 in embeddings:
# Pad 2D to 3D with zeros in Z
coords3d = np.hstack([embeddings[2], np.zeros((n_cells, 1), dtype=np.float32)])
elif 1 in embeddings:
# Pad 1D to 3D with zeros in Y and Z
coords3d = np.hstack([embeddings[1], np.zeros((n_cells, 2), dtype=np.float32)])
else:
raise ValueError("No usable embedding for 3D coordinates.")
# Track available dimensions for metadata
available_dimensions = sorted(embeddings.keys())
# Determine default dimension (priority: 3D > 2D > 1D)
default_dimension = 3 if 3 in embeddings else (2 if 2 in embeddings else 1)
# Print export settings summary
print("=" * 60)
print("Export Settings:")
print(f" Output directory: {out_dir}")
print(f" Compression: {'gzip level ' + str(compression) if compression else 'disabled'}")
print(f" Var (gene) quantization: {str(var_quantization) + '-bit' if var_quantization else 'disabled (float32)'}")
print(f" Obs continuous quantization: {str(obs_continuous_quantization) + '-bit' if obs_continuous_quantization else 'disabled (float32)'}")
print(f" Obs categorical dtype: {obs_categorical_dtype}")
print(f" Available dimensions: {available_dimensions}")
print(f" Default dimension: {default_dimension}D")
print(f" Coordinate normalization (per-dimension, aspect-ratio preserved):")
for dim in sorted(normalization_info.keys()):
info = normalization_info[dim]
print(f" {dim}D: range {info['original_range']:.2f} → [-1, 1]")
print("=" * 60)
# Validate and convert latent space
if latent_space is None:
raise ValueError("latent_space is required for outlier quantile calculation.")
latent = _to_dense(latent_space).astype(np.float32)
if latent.shape[0] != n_cells:
raise ValueError(
f"Latent space has {latent.shape[0]} cells, but embeddings have {n_cells} cells."
)
# Validate obs
if obs is None:
raise ValueError("obs DataFrame is required.")
if len(obs) != n_cells:
raise ValueError(
f"obs has {len(obs)} rows, but embeddings have {n_cells} cells."
)
# =========================================================================
# SAVE DIMENSIONAL EMBEDDING FILES
# =========================================================================
for dim, arr in embeddings.items():
dim_filename = f"points_{dim}d.bin"
dim_path = out_dir / dim_filename
check_path = Path(str(dim_path) + ".gz") if compression and compression > 0 else dim_path
if _file_exists_skip(check_path, check_path.name, force):
pass
else:
actual_path = _write_binary(dim_path, arr, compression)
suffix = " (gzip)" if compression else ""
print(f"✓ Wrote {dim}D positions ({arr.shape[0]:,} cells × {dim} dims) to {actual_path}{suffix}")
# =========================================================================
# SAVE VECTOR FIELDS (OPTIONAL)
# =========================================================================
vector_fields_identity: Optional[dict] = None
if vector_fields:
if not isinstance(vector_fields, dict):
raise TypeError("vector_fields must be a dict[str, array]")
vectors_dir = out_dir / "vectors"
vectors_dir.mkdir(parents=True, exist_ok=True)
def _infer_vector_shape(arr: Union[np.ndarray, sparse.spmatrix], name: str) -> tuple[int, np.ndarray]:
dense = _to_dense(arr)
dense = np.asarray(dense)
if dense.ndim == 1:
dense = dense.reshape(-1, 1)
if dense.ndim != 2:
raise ValueError(f"Vector field '{name}' must be 1D or 2D array, got shape {dense.shape}")
if dense.shape[1] not in (1, 2, 3):
raise ValueError(f"Vector field '{name}' must have 1/2/3 components, got shape {dense.shape}")
return int(dense.shape[1]), dense
def _label_for(field_id: str) -> str:
base = field_id[:-5] if field_id.endswith("_umap") else field_id
base = base.replace("_", " ").strip()
titled = (base[:1].upper() + base[1:]) if base else field_id
return f"{titled} (UMAP)" if field_id.endswith("_umap") else titled
# First pass: collect explicit keys (<field>_<dim>d), then fill gaps with implicit keys.
grouped: dict[str, dict[int, Union[np.ndarray, sparse.spmatrix]]] = {}
explicit_dims: dict[str, set[int]] = {}
suffix_re = re.compile(r"^(.*)_([123])d$")
for name, arr in vector_fields.items():
if arr is None:
continue
key = str(name)
match = suffix_re.match(key)
if not match:
continue
field_id = match.group(1)
dim = int(match.group(2))
explicit_dims.setdefault(field_id, set()).add(dim)
grouped.setdefault(field_id, {})[dim] = arr
for name, arr in vector_fields.items():
if arr is None:
continue
key = str(name)
if suffix_re.match(key):
continue
inferred_dim, _dense = _infer_vector_shape(arr, key)
field_id = key
if inferred_dim in explicit_dims.get(field_id, set()):
continue # clash-safe: explicit wins
grouped.setdefault(field_id, {})[inferred_dim] = arr
fields_meta: dict[str, dict] = {}
gz_suffix = ".gz" if compression else ""
for field_id, by_dim in grouped.items():
safe_id = _safe_filename_component(field_id)
if safe_id != field_id:
raise ValueError(
f"Vector field id '{field_id}' contains unsupported characters. "
f"Use '{safe_id}' instead."
)
files: dict[str, str] = {}
dims: list[int] = []
for dim in sorted(by_dim.keys()):
if dim not in embeddings:
print(
f" ⚠ Skipping vector field '{field_id}' {dim}D: "
f"embedding points_{dim}d not provided"
)
continue
inferred_dim, dense = _infer_vector_shape(by_dim[dim], f"{field_id}_{dim}d")
if inferred_dim != dim:
raise ValueError(
f"Vector field '{field_id}' declared as {dim}D but has shape {dense.shape}"
)
if dense.shape[0] != n_cells:
raise ValueError(
f"Vector field '{field_id}' {dim}D has {dense.shape[0]} rows, expected {n_cells}"
)
vec = np.asarray(dense, dtype=np.float32)
scale_factor = float(normalization_info.get(dim, {}).get("scale_factor", 1.0))
if scale_factor != 1.0:
vec *= scale_factor
filename = f"{field_id}_{dim}d.bin"
path = vectors_dir / filename
check_path = Path(str(path) + ".gz") if compression and compression > 0 else path
if _file_exists_skip(check_path, check_path.name, force):
pass
else:
actual_path = _write_binary(path, vec, compression)
suffix = " (gzip)" if compression else ""
print(
f"✓ Wrote vector field '{field_id}' {dim}D "
f"({vec.shape[0]:,} cells × {dim} comps) to {actual_path}{suffix}"
)
files[f"{dim}d"] = f"vectors/{filename}{gz_suffix}"
dims.append(dim)
if not dims:
continue
entry = {
"label": _label_for(field_id),
"available_dimensions": dims,
"default_dimension": max(dims),
"files": files,
}
if field_id.endswith("_umap"):
entry["basis"] = "umap"
fields_meta[field_id] = entry
if fields_meta:
default_field = "velocity_umap" if "velocity_umap" in fields_meta else sorted(fields_meta.keys())[0]
vector_fields_identity = {
"default_field": default_field,
"fields": fields_meta,
}
# Decide which obs columns to export
if obs_keys is None:
obs_keys = list(obs.columns)
else:
obs_keys = list(obs_keys)
missing = [k for k in obs_keys if k not in obs.columns]
if missing:
raise KeyError(
f"obs_keys contain columns not in obs: {missing}. "
f"Available columns: {list(obs.columns)}"
)
# Collect lightweight metadata for obs fields (used in dataset identity)
obs_field_summaries: list[dict] = []
for key in obs_keys:
s = obs[key]
if pd.api.types.is_categorical_dtype(s):
kind = "category"
elif pd.api.types.is_bool_dtype(s):
kind = "category"
elif pd.api.types.is_numeric_dtype(s):
kind = "continuous"
else:
kind = "category"
if kind == "continuous":
if obs_continuous_quantization is None:
dtype_str = "float32"
elif obs_continuous_quantization == 8:
dtype_str = "uint8"
else:
dtype_str = "uint16"
obs_field_summaries.append(
{
"key": str(key),
"kind": "continuous",
"quantized": obs_continuous_quantization is not None,
"quantization_bits": int(obs_continuous_quantization)
if obs_continuous_quantization is not None
else None,
"dtype": dtype_str,
}
)
else:
cat = s.astype("category")
categories = [str(c) for c in cat.cat.categories]
n_categories = len(categories)
if obs_categorical_dtype == "auto":
dtype_str = "uint8" if n_categories <= 254 else "uint16"
elif obs_categorical_dtype == "uint8":
dtype_str = "uint8"
else:
dtype_str = "uint16"
obs_field_summaries.append(
{
"key": str(key),
"kind": "category",
"category_count": n_categories,
"codes_dtype": dtype_str,
"outlier_quantized": obs_continuous_quantization is not None,
"outlier_quantization_bits": int(obs_continuous_quantization)
if obs_continuous_quantization is not None
else None,
}
)
# Check if obs manifest already exists
obs_manifest_path = out_dir / obs_manifest_filename
if _file_exists_skip(obs_manifest_path, "obs manifest", force):
pass
else:
# Compact format: separate lists for continuous and categorical fields
obs_continuous_fields: list = []
obs_categorical_fields: list = []
# Track dtype info for schema (will use first encountered)
continuous_dtype_info: dict = {}
categorical_dtype_info: dict = {}
for key in obs_keys:
s = obs[key]
safe_key = _safe_filename_component(key)
# Decide kind: continuous vs categorical
if pd.api.types.is_categorical_dtype(s):
kind = "category"
elif pd.api.types.is_bool_dtype(s):
kind = "category"
elif pd.api.types.is_numeric_dtype(s):
kind = "continuous"
else:
kind = "category"
if kind == "continuous":
values = pd.to_numeric(s, errors="coerce").to_numpy(dtype=np.float32, copy=False)
if values.shape[0] != n_cells:
raise ValueError(
f"Length mismatch for obs['{key}']: {values.shape[0]} vs {n_cells}"
)
# Apply quantization if requested
if obs_continuous_quantization is not None:
quantized, min_val, max_val, scale, stats = _quantize_continuous(
values, bits=obs_continuous_quantization, field_name=key
)
_report_quantization_stats(key, stats, "obs continuous")
if obs_continuous_quantization == 8:
dtype_str = "uint8"
ext = "u8"
else:
dtype_str = "uint16"
ext = "u16"
value_fname = f"{safe_key}.values.{ext}"
value_path = obs_binary_dir / value_fname
actual_path = _write_binary(value_path, quantized, compression)
# Adjust path in manifest if compressed
manifest_path = f"{obs_binary_dirname}/{value_fname}"
if compression:
manifest_path += ".gz"
# Compact format: [key, minValue, maxValue]
obs_continuous_fields.append([key, min_val, max_val])
if not continuous_dtype_info:
continuous_dtype_info["ext"] = ext
continuous_dtype_info["dtype"] = dtype_str
continuous_dtype_info["quantized"] = True
continuous_dtype_info["quantizationBits"] = obs_continuous_quantization
else:
# Full precision
value_fname = f"{safe_key}.values.f32"
value_path = obs_binary_dir / value_fname
actual_path = _write_binary(value_path, values, compression)
# Compact format: [key]
obs_continuous_fields.append([key])
if not continuous_dtype_info:
continuous_dtype_info["ext"] = "f32"
continuous_dtype_info["dtype"] = "float32"
continuous_dtype_info["quantized"] = False
else:
# Categorical
cat = s.astype("category")
categories = [str(c) for c in cat.cat.categories]
codes = cat.cat.codes.to_numpy(dtype=np.int32) # -1 for NaN
if codes.shape[0] != n_cells:
raise ValueError(
f"Length mismatch for obs['{key}']: {codes.shape[0]} vs {n_cells}"
)
n_categories = len(categories)
# Select dtype based on settings
if obs_categorical_dtype == "auto":
dtype, missing_value = _select_category_dtype(n_categories)
elif obs_categorical_dtype == "uint8":
if n_categories > 254:
raise ValueError(
f"Field '{key}' has {n_categories} categories, "
f"but uint8 can only hold 254. Use 'auto' or 'uint16'."
)
dtype, missing_value = np.uint8, 255
else: # uint16
dtype, missing_value = np.uint16, 65535
codes_typed = np.full(n_cells, missing_value, dtype=dtype)
valid_mask = codes >= 0
codes_typed[valid_mask] = codes[valid_mask].astype(dtype)
if dtype == np.uint8:
codes_fname = f"{safe_key}.codes.u8"
dtype_str = "uint8"
else:
codes_fname = f"{safe_key}.codes.u16"
dtype_str = "uint16"
codes_path = obs_binary_dir / codes_fname
actual_path = _write_binary(codes_path, codes_typed, compression)
manifest_codes_path = f"{obs_binary_dirname}/{codes_fname}"
if compression:
manifest_codes_path += ".gz"
# Compute centroids for all available dimensions
if centroid_outlier_quantile is None:
centroids_by_dim = {dim: [] for dim in embeddings.keys()}
else:
centroids_by_dim = _compute_centroids_for_all_dimensions(
embeddings,
codes,
categories,
outlier_quantile=centroid_outlier_quantile,
min_points=centroid_min_points,
)
# Compute per-cell outlier quantiles based on latent space
outlier_quantiles = _compute_latent_space_quantiles(
latent=latent,
codes=codes,
categories=categories,
min_points=centroid_min_points,
)
# Quantize outlier quantiles (they're always 0-1)
if obs_continuous_quantization is not None:
oq_quantized, oq_min, oq_max, oq_scale, oq_stats = _quantize_continuous(
outlier_quantiles, bits=obs_continuous_quantization, field_name=f"{key}_outliers"
)
_report_quantization_stats(f"{key}_outliers", oq_stats, "outlier quantiles")
if obs_continuous_quantization == 8:
oq_dtype_str = "uint8"
oq_ext = "u8"
else:
oq_dtype_str = "uint16"
oq_ext = "u16"
outlier_fname = f"{safe_key}.outliers.{oq_ext}"
outlier_path = obs_binary_dir / outlier_fname
_write_binary(outlier_path, oq_quantized, compression)
manifest_outlier_path = f"{obs_binary_dirname}/{outlier_fname}"
if compression:
manifest_outlier_path += ".gz"
# Compact format: [key, categories, codesDtype, codesMissingValue, centroidsByDim, outlierMinValue, outlierMaxValue]
# centroidsByDim is a dict keyed by dimension: {"1": [...], "2": [...], "3": [...]}
centroids_serializable = {str(dim): cents for dim, cents in centroids_by_dim.items()}
obs_categorical_fields.append([
key, categories, dtype_str, int(missing_value), centroids_serializable, oq_min, oq_max
])
if not categorical_dtype_info:
categorical_dtype_info["codesExt"] = "u8" if dtype == np.uint8 else "u16"
categorical_dtype_info["outlierExt"] = oq_ext
categorical_dtype_info["outlierDtype"] = oq_dtype_str
categorical_dtype_info["outlierQuantized"] = True
else:
# Full precision outliers
outlier_fname = f"{safe_key}.outliers.f32"
outlier_path = obs_binary_dir / outlier_fname
_write_binary(outlier_path, outlier_quantiles.astype(np.float32), compression)
# Compact format: [key, categories, codesDtype, codesMissingValue, centroidsByDim]
# centroidsByDim is a dict keyed by dimension: {"1": [...], "2": [...], "3": [...]}
centroids_serializable = {str(dim): cents for dim, cents in centroids_by_dim.items()}
obs_categorical_fields.append([
key, categories, dtype_str, int(missing_value), centroids_serializable
])
if not categorical_dtype_info:
categorical_dtype_info["codesExt"] = "u8" if dtype == np.uint8 else "u16"
categorical_dtype_info["outlierExt"] = "f32"
categorical_dtype_info["outlierDtype"] = "float32"
categorical_dtype_info["outlierQuantized"] = False
# Build compact manifest with schemas
gz_suffix = ".gz" if compression else ""
obs_schemas = {}
if continuous_dtype_info:
obs_schemas["continuous"] = {
"pathPattern": f"{obs_binary_dirname}/{{key}}.values.{continuous_dtype_info['ext']}{gz_suffix}",
"ext": continuous_dtype_info["ext"],
"dtype": continuous_dtype_info["dtype"],
"quantized": continuous_dtype_info.get("quantized", False),
}
if continuous_dtype_info.get("quantized"):
obs_schemas["continuous"]["quantizationBits"] = continuous_dtype_info["quantizationBits"]
if categorical_dtype_info:
obs_schemas["categorical"] = {
"codesPathPattern": f"{obs_binary_dirname}/{{key}}.codes.{{ext}}{gz_suffix}",
"outlierPathPattern": f"{obs_binary_dirname}/{{key}}.outliers.{categorical_dtype_info['outlierExt']}{gz_suffix}",
"outlierExt": categorical_dtype_info["outlierExt"],
"outlierDtype": categorical_dtype_info["outlierDtype"],
"outlierQuantized": categorical_dtype_info.get("outlierQuantized", False),
}
obs_manifest_payload = {
"_format": MANIFEST_FORMAT_VERSION,
"n_points": int(n_cells),
"centroid_outlier_quantile": float(centroid_outlier_quantile)
if centroid_outlier_quantile is not None
else None,
"latent_key": "latent_space",
"compression": compression if compression else None,
"_obsSchemas": obs_schemas,
"_continuousFields": obs_continuous_fields,
"_categoricalFields": obs_categorical_fields,
}
obs_manifest_path.write_text(json.dumps(obs_manifest_payload), encoding="utf-8")
total_fields = len(obs_continuous_fields) + len(obs_categorical_fields)
print(
f"✓ Wrote obs manifest ({total_fields} fields: {len(obs_continuous_fields)} continuous, "
f"{len(obs_categorical_fields)} categorical) to {obs_manifest_path} "
f"with binaries in {obs_binary_dirname}/"
)
# Process gene expression if provided
genes_to_export: list[str] = []
if gene_expression is not None:
if var is None:
raise ValueError("var DataFrame must be provided when gene_expression is given.")
gene_expr_is_sparse = sparse.issparse(gene_expression)
gene_expression_for_export = (
gene_expression.tocsc() if gene_expr_is_sparse and not sparse.isspmatrix_csc(gene_expression) else gene_expression
)
n_expr_cells = gene_expression_for_export.shape[0]
n_genes = gene_expression_for_export.shape[1]
if n_expr_cells != n_cells:
raise ValueError(
f"gene_expression has {n_expr_cells} cells, but embeddings have {n_cells} cells."
)
if len(var) != n_genes:
raise ValueError(
f"var has {len(var)} rows, but gene_expression has {n_genes} genes."
)
if var_gene_id_column == "index" or var_gene_id_column is None:
all_gene_ids = var.index.astype(str).tolist()
else:
if var_gene_id_column not in var.columns:
raise KeyError(
f"var_gene_id_column '{var_gene_id_column}' not found in var. "
f"Available columns: {list(var.columns)}"
)
all_gene_ids = var[var_gene_id_column].astype(str).tolist()
gene_id_to_idx = {gid: idx for idx, gid in enumerate(all_gene_ids)}
if gene_identifiers is None:
genes_to_export = all_gene_ids
else:
genes_to_export = list(gene_identifiers)
missing_genes = [g for g in genes_to_export if g not in gene_id_to_idx]
if missing_genes:
print(f"⚠ Warning: {len(missing_genes)} gene identifiers not found in var: {missing_genes[:5]}...")
genes_to_export = [g for g in genes_to_export if g in gene_id_to_idx]
var_manifest_path = out_dir / var_manifest_filename
if _file_exists_skip(var_manifest_path, "var manifest", force):
pass
else:
var_binary_dir = out_dir / var_binary_dirname
var_binary_dir.mkdir(parents=True, exist_ok=True)
var_manifest_fields: list[dict] = []
# Track problematic genes for aggregated reporting
genes_with_nan: list[str] = []
genes_with_inf: list[str] = []
genes_all_invalid: list[str] = []
for gene_id in tqdm.tqdm(genes_to_export, desc="Exporting genes"):
gene_idx = gene_id_to_idx[gene_id]
safe_gene_id = _safe_filename_component(gene_id)
if gene_expr_is_sparse:
col = gene_expression_for_export.getcol(gene_idx).toarray().flatten()
else:
col = gene_expression_for_export[:, gene_idx]
values = np.asarray(col, dtype=np.float32)
if values.shape[0] != n_cells:
raise ValueError(
f"Gene '{gene_id}' expression length mismatch: {values.shape[0]} vs {n_cells}"
)
# Apply quantization if requested
if var_quantization is not None:
quantized, min_val, max_val, scale, stats = _quantize_continuous(
values, bits=var_quantization, field_name=gene_id
)
# Track aggregated stats for genes (don't spam per-gene)
if stats.get("n_nan", 0) > 0:
genes_with_nan.append(gene_id)
if stats.get("n_inf", 0) > 0:
genes_with_inf.append(gene_id)
if stats.get("warning") == "all_invalid":
genes_all_invalid.append(gene_id)
if var_quantization == 8:
dtype_str = "uint8"
ext = "u8"
else:
dtype_str = "uint16"
ext = "u16"
value_fname = f"{safe_gene_id}.values.{ext}"
value_path = var_binary_dir / value_fname
_write_binary(value_path, quantized, compression)
manifest_path = f"{var_binary_dirname}/{value_fname}"
if compression:
manifest_path += ".gz"
# Compact format: [key, minValue, maxValue]
var_manifest_fields.append([gene_id, min_val, max_val])
else:
# Full precision
value_fname = f"{safe_gene_id}.values.f32"
value_path = var_binary_dir / value_fname
_write_binary(value_path, values, compression)
# Compact format: [key] for non-quantized
var_manifest_fields.append([gene_id])
# Report aggregated gene stats
if genes_with_nan:
print(f" ⚠ {len(genes_with_nan)} genes contain NaN values (mapped to missing marker)")
if len(genes_with_nan) <= 10:
print(f" Genes: {', '.join(genes_with_nan)}")
else:
print(f" First 10: {', '.join(genes_with_nan[:10])}...")
if genes_with_inf:
print(f" ⚠ {len(genes_with_inf)} genes contain Inf values (mapped to missing marker)")
if len(genes_with_inf) <= 10:
print(f" Genes: {', '.join(genes_with_inf)}")
if genes_all_invalid:
print(f" ⚠ WARNING: {len(genes_all_invalid)} genes have NO valid values (all NaN/Inf)")
if len(genes_all_invalid) <= 10:
print(f" Genes: {', '.join(genes_all_invalid)}")
# Build compact manifest with schema
gz_suffix = ".gz" if compression else ""
if var_quantization is not None:
ext = "u8" if var_quantization == 8 else "u16"
dtype_str = "uint8" if var_quantization == 8 else "uint16"
var_schema = {
"kind": "continuous",
"pathPattern": f"{var_binary_dirname}/{{key}}.values.{ext}{gz_suffix}",
"ext": ext,
"dtype": dtype_str,
"quantized": True,
"quantizationBits": var_quantization,
}
else:
var_schema = {
"kind": "continuous",
"pathPattern": f"{var_binary_dirname}/{{key}}.values.f32{gz_suffix}",
"ext": "f32",
"dtype": "float32",
"quantized": False,
}
var_manifest_payload = {
"_format": MANIFEST_FORMAT_VERSION,
"n_points": int(n_cells),
"var_gene_id_column": var_gene_id_column,
"compression": compression if compression else None,
"quantization": var_quantization,
"_varSchema": var_schema,
"fields": var_manifest_fields,
}
var_manifest_path.write_text(json.dumps(var_manifest_payload), encoding="utf-8")
compression_info = f", gzip level {compression}" if compression else ""
quant_info = f", {var_quantization}-bit quantized" if var_quantization else ""
print(
f"✓ Wrote var manifest ({len(var_manifest_fields)} genes{quant_info}{compression_info}) "
f"to {var_manifest_path}"
)
else:
print("INFO: No gene expression data provided, skipping var export.")
# Process connectivity data if provided
# GPU-optimized edge format for instanced rendering
connectivity_meta = {
"n_edges": None,
"max_neighbors": None,
"index_dtype": None,
}
if connectivities is not None:
connectivity_manifest_path = out_dir / connectivity_manifest_filename
if connectivities.shape[0] != n_cells or connectivities.shape[1] != n_cells:
raise ValueError(
f"Connectivity matrix shape {connectivities.shape} does not match "
f"number of cells {n_cells}."
)
# Determine optimal dtype based on cell count up front (used in hash even if we skip writing)
if n_cells <= 65535:
index_dtype = np.uint16
index_dtype_str = "uint16"
index_bytes = 2
elif n_cells <= 4294967295:
index_dtype = np.uint32
index_dtype_str = "uint32"
index_bytes = 4
else:
index_dtype = np.uint64
index_dtype_str = "uint64"
index_bytes = 8
approx_edges = None
if sparse.issparse(connectivities):
approx_edges = int(connectivities.nnz // 2)
else:
approx_edges = int(np.count_nonzero(connectivities) // 2)
connectivity_meta["n_edges"] = approx_edges
connectivity_meta["index_dtype"] = index_dtype_str
if _file_exists_skip(connectivity_manifest_path, "connectivity manifest", force):
pass
else:
connectivity_binary_dir = out_dir / connectivity_binary_dirname
connectivity_binary_dir.mkdir(parents=True, exist_ok=True)
if not sparse.isspmatrix_csr(connectivities):
connectivities = sparse.csr_matrix(connectivities)
# Symmetrize and binarize the connectivity matrix
connectivities_sym = connectivities + connectivities.T
connectivities_sym.data[:] = 1
connectivities_csr = connectivities_sym.tocsr()
indptr = connectivities_csr.indptr
indices = connectivities_csr.indices
# Extract unique edges (src < dst to avoid duplicates)
# Using vectorized operations for speed with large datasets
print(f" Extracting unique edges from {n_cells:,} cells...")
edge_sources = []
edge_destinations = []
max_neighbors_found = 0
# Process in chunks for memory efficiency with very large datasets
chunk_size = 100000
for chunk_start in range(0, n_cells, chunk_size):
chunk_end = min(chunk_start + chunk_size, n_cells)
for cell_idx in range(chunk_start, chunk_end):
start = indptr[cell_idx]
end = indptr[cell_idx + 1]
neighbor_count = end - start
if neighbor_count > max_neighbors_found:
max_neighbors_found = neighbor_count
for j in range(start, end):
neighbor_idx = indices[j]
# Only keep edges where src < dst (avoid duplicates)
if cell_idx < neighbor_idx:
edge_sources.append(cell_idx)
edge_destinations.append(neighbor_idx)
edge_sources = np.array(edge_sources, dtype=index_dtype)
edge_destinations = np.array(edge_destinations, dtype=index_dtype)
n_unique_edges = len(edge_sources)
connectivity_meta["n_edges"] = int(n_unique_edges)
connectivity_meta["max_neighbors"] = int(max_neighbors_found)
connectivity_meta["index_dtype"] = index_dtype_str
print(f" Found {n_unique_edges:,} unique edges, max {max_neighbors_found} neighbors/cell")
# Sort edges by source, then by destination for optimal gzip compression
print(f" Sorting edges for optimal compression...")
sort_idx = np.lexsort((edge_destinations, edge_sources))
edge_sources = edge_sources[sort_idx]
edge_destinations = edge_destinations[sort_idx]
# Write binary files (column-separated for better compression)
sources_fname = "edges.src.bin"
dests_fname = "edges.dst.bin"
sources_path = connectivity_binary_dir / sources_fname
dests_path = connectivity_binary_dir / dests_fname
_write_binary(sources_path, edge_sources, compression)
_write_binary(dests_path, edge_destinations, compression)
manifest_sources = f"{connectivity_binary_dirname}/{sources_fname}"
manifest_dests = f"{connectivity_binary_dirname}/{dests_fname}"
if compression:
manifest_sources += ".gz"
manifest_dests += ".gz"
# Write manifest
connectivity_manifest_payload = {
"format": "edge_pairs",
"n_cells": int(n_cells),
"n_edges": int(n_unique_edges),
"max_neighbors": int(max_neighbors_found),
"index_bytes": index_bytes,
"index_dtype": index_dtype_str,
"sourcesPath": manifest_sources,
"destinationsPath": manifest_dests,
"compression": compression if compression else None,
}
connectivity_manifest_path.write_text(
json.dumps(connectivity_manifest_payload), encoding="utf-8"
)
print(
f"✓ Wrote connectivity ({n_unique_edges:,} edges, "
f"max {max_neighbors_found} neighbors/cell, {index_dtype_str}) "
f"to {connectivity_binary_dir}"
)
else:
print("INFO: No connectivity data provided, skipping connectivity export.")
# =========================================================================
# Generate dataset_identity.json (metadata for multi-dataset support)
# =========================================================================
identity_path = out_dir / "dataset_identity.json"
# Try to import version from package
try:
from cellucid import __version__ as cellucid_version
except ImportError:
cellucid_version = "unknown"
# Determine dataset ID and name
if dataset_id is None:
if dataset_name:
dataset_id = _safe_filename_component(dataset_name)
else:
dataset_id = _safe_filename_component(out_dir.name)
if dataset_name is None:
dataset_name = out_dir.name
# Count genes
n_genes = 0
if gene_expression is not None and var is not None:
if sparse.issparse(gene_expression):
n_genes = gene_expression.shape[1]
else:
n_genes = np.asarray(gene_expression).shape[1]
# Count obs fields from the obs_field_summaries list populated earlier (lines 528-582)
# This works regardless of whether the obs manifest was written or skipped
n_obs_fields = len(obs_field_summaries)
n_categorical_fields = sum(1 for f in obs_field_summaries if f.get("kind") == "category")
n_continuous_fields = sum(1 for f in obs_field_summaries if f.get("kind") == "continuous")
# Validate that field counts add up (sanity check)
if n_obs_fields != n_categorical_fields + n_continuous_fields:
print(
f" ⚠ WARNING: Field count mismatch! "
f"n_obs_fields={n_obs_fields}, categorical={n_categorical_fields}, continuous={n_continuous_fields}"
)
# Debug: show fields with unexpected kinds
unexpected = [f for f in obs_field_summaries if f.get("kind") not in ("category", "continuous")]
if unexpected:
print(f" Fields with unexpected kind: {[f.get('key') for f in unexpected]}")
# Convert obs_field_summaries to the format expected in dataset_identity.json
# (simplified version with just key, kind, and n_categories for categorical)
identity_obs_fields = []
for field_info in obs_field_summaries:
entry = {
"key": field_info["key"],
"kind": field_info["kind"]
}
if field_info["kind"] == "category" and "category_count" in field_info:
entry["n_categories"] = field_info["category_count"]
identity_obs_fields.append(entry)
# Build source info if provided
source_info = None
if source_name or source_url or source_citation:
source_info = {}
if source_name:
source_info["name"] = source_name
if source_url:
source_info["url"] = source_url
if source_citation:
source_info["citation"] = source_citation
# Build export settings
export_settings = {
"compression": compression if compression else None,
"var_quantization": var_quantization,
"obs_continuous_quantization": obs_continuous_quantization,
"obs_categorical_dtype": obs_categorical_dtype
}
# Build embeddings metadata
gz_suffix = ".gz" if compression else ""
embeddings_meta = {
"available_dimensions": available_dimensions,
"default_dimension": default_dimension,
"files": {}
}
for dim in available_dimensions:
embeddings_meta["files"][f"{dim}d"] = f"points_{dim}d.bin{gz_suffix}"
# Build identity payload
identity_payload = {
"version": 2, # Bumped version for multi-dimensional support
"id": dataset_id,
"name": dataset_name,
"description": dataset_description or "",
"created_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
"cellucid_data_version": cellucid_version,
"stats": {
"n_cells": int(n_cells),
"n_genes": int(n_genes),
"n_obs_fields": int(n_obs_fields),
"n_categorical_fields": int(n_categorical_fields),
"n_continuous_fields": int(n_continuous_fields),
"has_connectivity": connectivity_meta.get("n_edges") is not None,
"n_edges": connectivity_meta.get("n_edges")
},
"embeddings": embeddings_meta,
"obs_fields": identity_obs_fields,
"export_settings": export_settings
}
if source_info:
identity_payload["source"] = source_info
if vector_fields_identity:
identity_payload["vector_fields"] = vector_fields_identity
identity_path.write_text(json.dumps(identity_payload, indent=2), encoding="utf-8")
print(f"✓ Wrote dataset identity to {identity_path}")
def generate_datasets_manifest(
    exports_dir: Union[str, Path] = DEFAULT_EXPORT_DIR,
    output_filename: str = "datasets.json",
    default_dataset: Optional[str] = None,
) -> Optional[Path]:
    """
    Scan an exports directory for dataset_identity.json files and generate a datasets.json manifest.

    Every immediate subdirectory of ``exports_dir`` that contains a readable
    ``dataset_identity.json`` holding a JSON object contributes one entry
    (``id``, ``path``, ``name`` and, when present and non-zero, ``n_cells`` /
    ``n_genes``) to the manifest. Subdirectories whose identity file is
    missing, unreadable, not valid JSON, or not a JSON object are skipped
    with a printed warning; they never abort the scan.

    Parameters
    ----------
    exports_dir : Path or str
        Directory containing dataset subdirectories (default: exports/).
    output_filename : str
        Name of the output manifest file (default: datasets.json).
    default_dataset : str or None
        ID of the default dataset. If None, or if the given ID does not match
        any discovered dataset, the first dataset found (in sorted
        subdirectory order) is used.

    Returns
    -------
    Path or None
        Path to the generated manifest file, or None when no dataset was
        found (in which case nothing is written).

    Raises
    ------
    FileNotFoundError
        If ``exports_dir`` does not exist.
    NotADirectoryError
        If ``exports_dir`` exists but is not a directory.

    Example
    -------
    >>> from cellucid.prepare_data import generate_datasets_manifest
    >>> generate_datasets_manifest("./exports", default_dataset="my_dataset")
    """
    exports_dir = Path(exports_dir)
    if not exports_dir.exists():
        raise FileNotFoundError(f"Exports directory not found: {exports_dir}")
    if not exports_dir.is_dir():
        # Fail with a clear message up front instead of from inside iterdir().
        raise NotADirectoryError(f"Exports path is not a directory: {exports_dir}")

    print(f"Scanning {exports_dir} for datasets...")
    datasets = []

    # Scan subdirectories for dataset_identity.json (sorted for a stable order)
    for subdir in sorted(exports_dir.iterdir()):
        if not subdir.is_dir():
            continue
        identity_file = subdir / "dataset_identity.json"
        if not identity_file.exists():
            print(f" ⚠ Skipping {subdir.name}: no dataset_identity.json")
            continue

        # Only reading and parsing can fail for reasons outside our control,
        # so the try body is limited to exactly that.
        try:
            identity = json.loads(identity_file.read_text(encoding="utf-8"))
        except json.JSONDecodeError as e:
            print(f" ⚠ Skipping {subdir.name}: invalid JSON - {e}")
            continue
        except (OSError, ValueError) as e:
            # OSError: unreadable file; ValueError covers UnicodeDecodeError.
            print(f" ⚠ Skipping {subdir.name}: {e}")
            continue

        if not isinstance(identity, dict):
            print(f" ⚠ Skipping {subdir.name}: dataset_identity.json is not a JSON object")
            continue

        dataset_entry = {
            "id": identity.get("id", subdir.name),
            "path": f"{subdir.name}/",
            "name": identity.get("name", subdir.name),
        }

        # Include quick stats for display in dropdown. A missing, null, or
        # malformed "stats" value must not discard an otherwise valid dataset.
        stats = identity.get("stats")
        if not isinstance(stats, dict):
            stats = {}
        for stat_key in ("n_cells", "n_genes"):
            if stats.get(stat_key):
                dataset_entry[stat_key] = stats[stat_key]

        datasets.append(dataset_entry)
        print(f" ✓ Found dataset: {dataset_entry['name']} ({dataset_entry['id']})")

    if not datasets:
        print(" No datasets found!")
        return None

    # Determine default dataset
    if default_dataset:
        if not any(d["id"] == default_dataset for d in datasets):
            print(f" ⚠ Warning: default_dataset '{default_dataset}' not found, using first dataset")
            default_dataset = datasets[0]["id"]
    else:
        default_dataset = datasets[0]["id"]

    # Build manifest
    manifest = {
        "version": 1,
        "default": default_dataset,
        "datasets": datasets,
    }

    # Write manifest
    manifest_path = exports_dir / output_filename
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    print(f"✓ Wrote datasets manifest with {len(datasets)} datasets to {manifest_path}")
    print(f" Default dataset: {default_dataset}")
    return manifest_path