Source code for delnx.tl._rank_de

"""Rank-based differential expression analysis for single-cell data.

This module provides rank-based differential expression analysis using Area Under the
ROC Curve (AUROC) statistics. It efficiently handles sparse gene expression matrices
by implementing optimized ranking algorithms with optional tie correction for improved
statistical accuracy.
"""

import jax
import jax.numpy as jnp
import numba
import numpy as np
import pandas as pd
import statsmodels.api as sm
import tqdm
from anndata import AnnData
from scipy import sparse, stats

from delnx._logging import logger
from delnx._utils import _get_layer

# =============================================================================
# Core Numba-optimized ranking functions
# =============================================================================


@numba.njit
def _rankdata(arr: np.ndarray) -> np.ndarray:
    """Fast 1D ranking with average tie handling."""
    n = len(arr)
    if n <= 1:
        return np.array([1.0]) if n == 1 else np.empty(0, dtype=np.float64)

    sorter = np.argsort(arr)
    arr = arr[sorter]
    obs = np.concatenate((np.array([True]), arr[1:] != arr[:-1]))
    dense = np.empty(obs.size, dtype=np.int64)
    dense[sorter] = obs.cumsum()
    # cumulative counts of each unique value
    count = np.concatenate((np.flatnonzero(obs), np.array([len(obs)])))
    ranks = 0.5 * (count[dense] + count[dense - 1] + 1)
    return ranks


@numba.njit
def _rankdata_with_ties(arr: np.ndarray) -> tuple[np.ndarray, np.float64]:
    """Fast 1D ranking with tie correction calculation."""
    n = len(arr)
    if n <= 1:
        return (np.array([1.0]) if n == 1 else np.empty(0, dtype=np.float64), 1.0)

    sorter = np.argsort(arr)
    sorted_arr = arr[sorter]

    # Find tie group boundaries more efficiently
    tie_starts = [0]
    for i in range(1, n):
        if sorted_arr[i] != sorted_arr[i - 1]:
            tie_starts.append(i)
    tie_starts.append(n)

    # Calculate ranks and tie correction simultaneously
    ranks = np.empty(n, dtype=np.float64)
    tie_sum = 0.0

    for i in range(len(tie_starts) - 1):
        start, end = tie_starts[i], tie_starts[i + 1]
        tie_size = end - start
        avg_rank = (start + end + 1) / 2.0

        # Assign average rank to tied group
        for j in range(start, end):
            ranks[sorter[j]] = avg_rank

        # Accumulate tie contribution
        if tie_size > 1:
            tie_sum += tie_size**3 - tie_size

    # Calculate tie correction
    tie_correction = 1.0 - tie_sum / (n**3 - n) if n >= 2 else 1.0
    return ranks, tie_correction


@numba.njit(parallel=True, cache=True)
def _rank_sparse_batch_parallel(
    data: np.ndarray, indptr: np.ndarray, nrows: int, ncols: int, use_ties: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Optimized parallel ranking for sparse matrix batches."""
    ranked_data = np.empty_like(data, dtype=np.float64)
    zero_ranks = np.zeros(ncols, dtype=np.float64)
    tie_corrections = np.ones(ncols, dtype=np.float64)

    for col_idx in numba.prange(ncols):
        start_idx = indptr[col_idx]
        end_idx = indptr[col_idx + 1]
        n_nonzero = end_idx - start_idx

        if n_nonzero == 0:
            zero_ranks[col_idx] = 0.0
            continue

        n_zero = nrows - n_nonzero
        zero_ranks[col_idx] = (n_zero + 1) / 2.0 if n_zero > 0 else 0.0

        col_data = data[start_idx:end_idx]

        if use_ties:
            nonzero_ranks, nonzero_tie_corr = _rankdata_with_ties(col_data)
            # Calculate combined tie correction including zeros
            if nrows >= 2:
                zero_contrib = n_zero**3 - n_zero if n_zero > 1 else 0.0
                nonzero_contrib = (1.0 - nonzero_tie_corr) * (n_nonzero**3 - n_nonzero)
                total_contrib = zero_contrib + nonzero_contrib
                tie_corrections[col_idx] = 1.0 - total_contrib / (nrows**3 - nrows)
        else:
            nonzero_ranks = _rankdata(col_data)

        ranked_data[start_idx:end_idx] = nonzero_ranks + n_zero

    return ranked_data, zero_ranks, tie_corrections


def _rank_sparse_batch_serial(
    data: np.ndarray, indptr: np.ndarray, nrows: int, ncols: int, use_ties: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Serial ranking for sparse matrix batches using scipy.stats.rankdata."""
    ranked_data = np.empty_like(data, dtype=np.float64)
    zero_ranks = np.zeros(ncols, dtype=np.float64)
    tie_corrections = np.ones(ncols, dtype=np.float64)

    for col_idx in range(ncols):
        start_idx = indptr[col_idx]
        end_idx = indptr[col_idx + 1]

        if end_idx > start_idx:  # if column has non-zero elements
            n_nonzero = end_idx - start_idx
            n_zero = nrows - n_nonzero

            # Calculate zero rank (average of ranks 1 to n_zero)
            zero_ranks[col_idx] = (n_zero + 1) / 2.0 if n_zero > 0 else 0.0

            # Rank non-zero values
            col_data = data[start_idx:end_idx]
            nonzero_ranks = stats.rankdata(col_data, method="average")
            ranked_data[start_idx:end_idx] = nonzero_ranks + n_zero

        else:
            # Empty column case
            zero_ranks[col_idx] = 0.0

    return ranked_data, zero_ranks, tie_corrections


# =============================================================================
# JAX AUROC calculation
# =============================================================================


@jax.jit
def _auroc_batch_with_ties(
    X: jnp.ndarray, one_hot: jnp.ndarray, zero_ranks: jnp.ndarray, tie_corrections: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Optimized AUROC calculation with tie correction."""
    # Reconstruct full ranks efficiently
    nonzero_mask = X > 0
    full_ranks = jnp.where(nonzero_mask, X, zero_ranks[None, :])

    # Compute statistics
    n_pos = jnp.maximum(one_hot.sum(axis=0), 1e-10)
    n_neg = jnp.maximum(X.shape[0] - n_pos, 1e-10)

    # Mann-Whitney U statistic
    rank_sum_pos = full_ranks.T @ one_hot
    U = rank_sum_pos - n_pos * (n_pos + 1) / 2

    # AUROC and p-values
    aucs = jnp.clip(U / (n_pos * n_neg), 0.0, 1.0)

    # Tie-corrected p-values
    mu_U = n_pos * n_neg / 2
    sigma_U_sq = (n_pos * n_neg * (n_pos + n_neg + 1) / 12) * tie_corrections[:, None]
    sigma_U = jnp.sqrt(jnp.maximum(sigma_U_sq, 1e-20))

    z_score = jnp.where(sigma_U > 1e-10, jnp.abs(U - mu_U) / sigma_U, 0.0)

    return aucs, z_score


# =============================================================================
# Batch processing function
# =============================================================================


def _process_batch(
    X_batch: sparse.csc_matrix, one_hot: jnp.ndarray, use_ties: bool, use_parallel: bool
) -> tuple[np.ndarray, np.ndarray]:
    """Process a single batch with optimized ranking and AUROC calculation."""
    # Ensure data is writable
    X_batch.data.setflags(write=True)

    # Choose ranking algorithm
    rank_func = _rank_sparse_batch_parallel if use_parallel else _rank_sparse_batch_serial
    ranked_data, zero_ranks, tie_corrections = rank_func(
        X_batch.data, X_batch.indptr, X_batch.shape[0], X_batch.shape[1], use_ties
    )

    # Convert to JAX arrays efficiently
    X_batch.data = ranked_data
    X_dense = jnp.asarray(X_batch.toarray(), dtype=jnp.float32)
    zero_ranks_jax = jnp.asarray(zero_ranks, dtype=jnp.float32)
    tie_corrections_jax = jnp.asarray(tie_corrections, dtype=jnp.float32)

    aucs, z_scores = _auroc_batch_with_ties(X_dense, one_hot, zero_ranks_jax, tie_corrections_jax)
    pvals = 2 * jax.scipy.stats.norm.sf(z_scores)

    # Single device-to-host transfer
    return np.asarray(aucs), np.asarray(pvals), np.asarray(z_scores)


def _validate_inputs(adata: AnnData, condition_key: str, min_samples: int) -> list[str]:
    """Validate inputs and return valid conditions."""
    if condition_key not in adata.obs.columns:
        raise ValueError(f"Condition key '{condition_key}' not found in adata.obs")

    condition_values = adata.obs[condition_key].values
    unique_conditions = np.unique(condition_values)

    valid_conditions = []
    for condition in unique_conditions:
        n_samples = np.sum(condition_values == condition)
        if n_samples >= min_samples:
            valid_conditions.append(condition)

    if len(valid_conditions) < 2:
        raise ValueError(f"Need at least 2 valid conditions, found {len(valid_conditions)}")

    return valid_conditions


def _determine_algorithm(n_cpus: int, use_ties: bool, verbose: bool) -> tuple[bool, int]:
    """Set threads and determine whether to use parallel processing."""
    if n_cpus is None:
        n_cpus = numba.get_num_threads()
    else:
        max_threads = numba.get_num_threads()
        if n_cpus > max_threads:
            logger.warning(f"Requested n_cpus={n_cpus} exceeds available {max_threads}, using {max_threads}")
            n_cpus = max_threads
        numba.set_num_threads(n_cpus)

    # Decide on parallel processing
    # Serial processing (without ties) tends to be faster for small CPU counts
    use_parallel = (n_cpus >= 4) or use_ties

    return use_parallel


# =============================================================================
# Main user-facing function
# =============================================================================


[docs] def rank_de( adata: AnnData, condition_key: str, layer: str | None = None, use_ties: bool = False, multitest_method: str = "fdr_bh", n_cpus: int | None = None, min_samples: int = 2, batch_size: int | None = 512, verbose: bool = True, ) -> pd.DataFrame: """Perform rank-based differential expression analysis using AUROC statistics. This function identifies differentially expressed features between conditions using Area Under the ROC Curve (AUROC) analysis on ranked gene expression data. It efficiently handles sparse matrices with memory-optimized batch processing and provides optional tie correction for improved statistical accuracy. Parameters ---------- adata : AnnData AnnData object containing expression data and metadata. condition_key : str Column name in `adata.obs` containing condition labels. layer : str | None, default=None Layer in `adata.layers` to use for expression data. If None, uses `adata.X`. use_ties : bool, default=False Whether to apply tie correction to p-value calculations. Recommended for sparse data with many identical values (especially zeros). multitest_method : str, default='fdr_bh' Method for multiple testing correction. Accepts any method supported by :func:`statsmodels.stats.multipletests`. Common options include: - "fdr_bh": Benjamini-Hochberg FDR correction - "bonferroni": Bonferroni correction n_cpus : int | None, default=None Number of CPU cores for parallel processing. If None, uses available threads. Parallel processing is enabled when n_cpus >= 4 or use_ties=True. min_samples : int, default=2 Minimum samples required per condition. Conditions with fewer samples excluded. batch_size : int | None, default=512 Features per batch. Larger values use more memory but may be more efficient. If None, processes all features at once. verbose : bool, default=True Whether to show progress and algorithm information. Returns ------- pd.DataFrame Results with columns: - "feature": Feature/gene names - "condition": Condition label (one-vs-all comparison) - "auroc": AUROC values (0.5=random, >0.5=upregulated, <0.5=downregulated) - "pval": Two-tailed p-values from Mann-Whitney U test - "tie_corrected": Whether tie correction was applied Examples -------- Basic one-vs-all differential expression: >>> import delnx as dx >>> results = dx.tl.rank_de(adata, condition_key="cell_type") With tie correction for improved p-values: >>> results = dx.tl.rank_de(adata, condition_key="treatment", use_ties=True) Custom batch size for memory optimization: >>> results = dx.tl.rank_de(adata, condition_key="condition", batch_size=1024) Notes ----- This implementation uses several optimizations: - Numba JIT compilation for fast ranking algorithms - JAX acceleration for AUROC calculations - Memory-efficient batch processing to handle large datasets - Automatic algorithm selection based on data characteristics - Optional tie correction for statistical accuracy with sparse data """ # Input validation valid_conditions = _validate_inputs(adata, condition_key, min_samples) use_parallel = _determine_algorithm(n_cpus, use_ties, verbose) # Get data and encode conditions X = _get_layer(adata, layer) if not sparse.issparse(X): X = sparse.csc_matrix(X) elif not isinstance(X, sparse.csc_matrix): X = X.tocsc() condition_values = adata.obs[condition_key].values groups = np.zeros(len(condition_values), dtype=np.int32) for i, condition in enumerate(valid_conditions): groups[condition_values == condition] = i # Pre-compute one-hot encoding for efficiency n_classes = len(valid_conditions) one_hot = jax.nn.one_hot(jnp.array(groups, dtype=jnp.int32), n_classes, dtype=jnp.float32) if verbose: logger.info(f"Processing {X.shape[1]} features across {len(valid_conditions)} conditions", verbose=True) # Batch processing batch_size = batch_size or X.shape[1] n_features = X.shape[1] all_aucs, all_pvals, all_z_scores = [], [], [] for start_idx in tqdm.tqdm(range(0, n_features, batch_size), disable=not verbose): end_idx = min(start_idx + batch_size, n_features) X_batch = X[:, start_idx:end_idx].copy() aucs, pvals, z_scores = _process_batch(X_batch, one_hot, use_ties, use_parallel) all_aucs.append(aucs) all_pvals.append(pvals) all_z_scores.append(z_scores) # Combine results and create output DataFrame final_aucs = np.concatenate(all_aucs, axis=0) final_pvals = np.concatenate(all_pvals, axis=0) final_z_scores = np.concatenate(all_z_scores, axis=0) # Reshape to long format n_features, n_conditions = final_aucs.shape features_long = np.tile(adata.var_names.values, n_conditions) conditions_long = np.repeat(valid_conditions, n_features) aucs_long = final_aucs.T.flatten() pvals_long = final_pvals.T.flatten() z_scores_long = final_z_scores.T.flatten() results = pd.DataFrame( { "feature": features_long, "condition": conditions_long, "auroc": aucs_long, "z_score": z_scores_long, "pval": pvals_long, } ) # Perform multiple testing correction padj = sm.stats.multipletests( results["pval"][results["pval"].notna()].values, method=multitest_method, )[1] results["padj"] = np.nan # Initialize with NaN results.loc[results["pval"].notna(), "padj"] = padj results["pval"] = np.clip(results["pval"], 1e-50, 1) results["padj"] = np.clip(results["padj"], 1e-50, 1) results = results.sort_values(by=["condition", "auroc", "pval"], ascending=[True, False, True]).reset_index( drop=True ) return results