Source code for delnx.tl._grouped
"""Grouped differential expression wrapper.
Runs any DE function per group (e.g., cell type) and combines results
with cross-group multiple testing correction.
"""
from collections.abc import Callable
import numpy as np
import pandas as pd
import statsmodels.api as sm
from anndata import AnnData
from delnx._logging import logger
[docs]
def grouped(
    func: Callable[..., pd.DataFrame],
    adata: AnnData,
    group_key: str,
    min_samples: int = 2,
    multitest_method: str = "fdr_bh",
    verbose: bool = True,
    **kwargs,
) -> pd.DataFrame:
    """Run a DE function separately for each group and combine results.

    Thin orchestrator that subsets ``adata`` by each unique (non-missing) value
    of ``adata.obs[group_key]``, calls ``func`` on each subset, and re-corrects
    p-values across all groups.

    Parameters
    ----------
    func : callable
        DE function with signature ``func(adata, **kwargs) -> pd.DataFrame``.
        The returned DataFrame must contain a ``pval`` column.
        Works with :func:`de`, :func:`rank_de`, or any custom function
        (e.g., a lambda wrapping :func:`nb_fit` + :func:`nb_test`).
    adata : AnnData
        Annotated data object.
    group_key : str
        Column in ``adata.obs`` defining groups (e.g., ``"cell_type"``).
        Observations with a missing label in this column are ignored.
    min_samples : int, default=2
        Minimum observations per group to run the analysis. Groups with
        fewer observations are skipped with a warning.
    multitest_method : str, default="fdr_bh"
        Method for multiple testing correction across all groups
        (see :func:`statsmodels.stats.multitest.multipletests`).
    verbose : bool, default=True
        Whether to print progress messages.
    **kwargs
        Passed through to ``func``.

    Returns
    -------
    pd.DataFrame
        Combined results with an additional ``group`` column. The ``padj``
        column is always re-computed across all groups (any ``padj`` returned
        by ``func`` is discarded); rows with a missing ``pval`` get NaN.

    Raises
    ------
    ValueError
        If ``group_key`` is not a column of ``adata.obs``, if no group produced
        results, or if the combined results lack a ``pval`` column.

    Notes
    -----
    Only ``ValueError`` raised by ``func`` is treated as a per-group failure
    (logged and skipped); any other exception propagates to the caller.

    Examples
    --------
    Per-cell-type logistic regression:

    >>> results = dx.tl.grouped(dx.tl.de, adata, group_key="cell_type",
    ...                         condition_key="treatment", reference="control",
    ...                         contrast="treatment[T.drugA]")

    Per-cell-type rank-based markers:

    >>> results = dx.tl.grouped(dx.tl.rank_de, adata, group_key="cell_type", condition_key="treatment")

    Per-cell-type negative binomial DE:

    >>> def nb_de(adata, **kw):
    ...     fit = dx.tl.nb_fit(adata, **kw)
    ...     return dx.tl.nb_test(adata, fit)
    >>> results = dx.tl.grouped(nb_de, adata, group_key="cell_type", condition_key="treatment")
    """
    if group_key not in adata.obs.columns:
        raise ValueError(f"Group key '{group_key}' not found in adata.obs")

    labels = adata.obs[group_key]
    # Missing labels belong to no group. Report them once instead of letting
    # NaN show up as a bogus "group 'nan' with 0 samples" (NaN never equals itself).
    n_missing = int(labels.isna().sum())
    if n_missing:
        logger.warning(
            f"Ignoring {n_missing} observations with missing '{group_key}' labels",
            verbose=verbose,
        )
    label_values = labels.values

    results = []
    for group in labels.dropna().unique():
        mask = label_values == group
        n_obs = int(np.sum(mask))
        if n_obs < min_samples:
            logger.warning(f"Skipping group '{group}' with {n_obs} < {min_samples} samples", verbose=verbose)
            continue

        logger.info(f"Running DE for group: {group}", verbose=verbose)
        # Keep the try narrow: only failures inside ``func`` count as "DE failed".
        try:
            group_results = func(adata[mask, :], **kwargs)
        except ValueError as e:
            logger.warning(f"DE failed for group '{group}': {e}. Skipping.", verbose=verbose)
            continue
        # ``assign`` returns a new frame, so a frame owned/cached by ``func``
        # is never mutated (and repeated objects don't share the last label).
        results.append(group_results.assign(group=group))

    if not results:
        hint = "" if verbose else " or set verbose=True for details"
        raise ValueError(f"Differential expression analysis failed for all groups. Check input data{hint}.")

    combined = pd.concat(results, axis=0).reset_index(drop=True)

    if "pval" not in combined.columns:
        raise ValueError("DE function must return a DataFrame with a 'pval' column.")

    # Re-correct p-values across all groups. Always reset ``padj`` so that no
    # per-group adjusted values from ``func`` leak through, even when every
    # p-value is missing.
    valid = combined["pval"].notna()
    combined["padj"] = np.nan
    if valid.any():
        combined.loc[valid, "padj"] = sm.stats.multipletests(
            combined.loc[valid, "pval"].to_numpy(), method=multitest_method
        )[1]

    # Sort by group, then by any contrast-identifying columns, then by padj.
    sort_cols = ["group"]
    if "test_condition" in combined.columns:
        sort_cols += ["test_condition", "ref_condition"]
    sort_cols += ["condition", "padj"]
    # Only sort by columns that exist
    sort_cols = [c for c in sort_cols if c in combined.columns]
    return combined.sort_values(by=sort_cols).reset_index(drop=True)