"""Negative binomial differential expression analysis.
This module provides the main interface for GPU-accelerated negative binomial
differential expression analysis using the glmGamPoi approach. Core solvers
are vectorized with JAX vmap for batch-parallel processing across genes.
References
----------
Ahlmann-Eltze, C., & Huber, W. (2020). glmGamPoi: fitting Gamma-Poisson
generalized linear models on single cell count data. Bioinformatics.
"""
from dataclasses import dataclass, field
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import statsmodels.api as sm
import tqdm
from anndata import AnnData
from scipy import sparse
from delnx._logging import logger
from delnx._utils import _get_layer, _to_dense
from delnx.models._glm_gp import (
compute_gp_deviance,
compute_gp_deviance_batch,
estimate_dispersion_mle,
estimate_dispersion_mle_batch,
estimate_dispersion_moments,
estimate_dispersion_moments_batch,
fit_beta_fisher_scoring,
fit_beta_newton_batch,
fit_beta_one_group,
fit_beta_one_group_batch,
)
from delnx.models._quasi_likelihood import (
compute_ql_dispersions,
fit_dispersion_trend,
ql_f_test,
shrink_ql_dispersions,
)
from delnx.pp._size_factors import size_factors as compute_size_factors
# =============================================================================
# Result dataclass
# =============================================================================
[docs]
@dataclass
class NBFitResult:
"""Result container for nb_fit.
Attributes
----------
beta : np.ndarray
Fitted coefficients, shape (n_genes, n_coefficients).
overdispersions : np.ndarray
MLE dispersion estimates per gene, shape (n_genes,).
mu : np.ndarray
Fitted mean values, shape (n_samples, n_genes).
size_factors : np.ndarray
Size factors per sample, shape (n_samples,).
deviances : np.ndarray
Deviance per gene, shape (n_genes,).
design_matrix : np.ndarray
Design matrix used, shape (n_samples, n_coefficients).
design_column_names : list[str]
Column names for the design matrix (e.g., ["intercept", "condB"]).
feature_names : pd.Index
Feature/gene names.
layer : str | None
Layer used for count data.
condition_key : str | None
Condition key used for design.
ql_dispersions : np.ndarray | None
Quasi-likelihood dispersions (if shrinkage applied).
df0_prior : float
Prior degrees of freedom from QL shrinkage.
dispersion_trend : np.ndarray | None
Fitted dispersion trend.
converged : np.ndarray
Convergence status per gene.
"""
beta: np.ndarray
overdispersions: np.ndarray
mu: np.ndarray
size_factors: np.ndarray
deviances: np.ndarray
design_matrix: np.ndarray
design_column_names: list[str]
feature_names: pd.Index
layer: str | None = None
condition_key: str | None = None
ql_dispersions: np.ndarray | None = None
df0_prior: float = 0.0
dispersion_trend: np.ndarray | None = None
converged: np.ndarray = field(default_factory=lambda: np.array([]))
# =============================================================================
# Main fitting function
# =============================================================================
[docs]
def nb_fit(
    adata: AnnData,
    condition_key: str | None = None,
    formula: str | None = None,
    design: np.ndarray | None = None,
    design_column_names: list[str] | None = None,
    reference: str | None = None,
    covariate_keys: list[str] | None = None,
    size_factors: str | np.ndarray | None = "normed_sum",
    layer: str | None = None,
    overdispersion: bool = True,
    batch_size: int = 512,
    maxiter: int = 100,
    verbose: bool = True,
    overdispersion_shrinkage: bool = True,
    do_cox_reid_adjustment: bool = True,
) -> NBFitResult:
    """Fit Gamma-Poisson (negative binomial) GLMs to count data.

    This implements the glmGamPoi approach for fast and accurate differential
    expression analysis using GPU-accelerated Newton-Raphson with
    quasi-likelihood dispersion shrinkage.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing count data.
    condition_key : str | None, default=None
        Column in ``adata.obs`` for condition labels. Creates a design matrix
        with intercept + condition indicators. Mutually exclusive with
        ``formula``.
    formula : str | None, default=None
        R-style formula for the design matrix (e.g., ``"~ treatment + batch"``
        or ``"~ treatment * batch"``). Parsed by patsy. Mutually exclusive
        with ``condition_key``.
    design : np.ndarray | None, default=None
        Custom design matrix. If provided, overrides ``condition_key``,
        ``formula``, ``reference``, and ``covariate_keys``.
        Shape should be (n_samples, n_coefficients).
    design_column_names : list[str] | None, default=None
        Column names for a custom ``design`` matrix. Enables string-based
        ``contrast`` in :func:`nb_test`. If None with custom ``design``,
        generic names ``coef_0``, ``coef_1``, ... are used.
    reference : str | None, default=None
        Reference level for the condition. This level becomes the intercept.
        If None, the alphabetically first level is used.
    covariate_keys : list[str] | None, default=None
        Columns in ``adata.obs`` to include as covariates in the design
        matrix. Only used with ``condition_key`` (include covariates in
        ``formula`` directly).
    size_factors : str | np.ndarray | None, default="normed_sum"
        Size factors for normalization. Can be:

        - ``"normed_sum"``: Compute using normalized sum method
        - ``"poscounts"``: DESeq2-style positive counts method
        - np.ndarray: Pre-computed size factors, shape (n_samples,)
        - None: No size factor normalization (all 1s)

        When a method name is given, the size factors are computed through
        ``compute_size_factors`` and read back from
        ``adata.obs["size_factors"]``, so ``adata`` is modified.
    layer : str | None, default=None
        Layer in ``adata.layers`` containing counts. If None, uses ``adata.X``.
    overdispersion : bool, default=True
        Whether to estimate overdispersion. If False, a fixed near-zero
        dispersion (1e-10) is used, i.e. effectively a Poisson model.
    batch_size : int, default=512
        Number of genes to process in each batch. May be reduced
        automatically for designs with more than 10 coefficients.
    maxiter : int, default=100
        Maximum iterations for Newton-Raphson.
    verbose : bool, default=True
        Whether to show progress.
    overdispersion_shrinkage : bool, default=True
        Whether to apply quasi-likelihood shrinkage to dispersions. Only has
        an effect when ``overdispersion`` is True.
    do_cox_reid_adjustment : bool, default=True
        Whether to apply Cox-Reid adjustment to dispersion MLE.

    Returns
    -------
    NBFitResult
        Fitted model results containing coefficients, dispersions, and
        fitted values.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, if pre-computed ``size_factors``
        do not have shape (n_samples,), if ``design`` is not a 2-D matrix with
        n_samples rows and at least one column, or if ``design_column_names``
        does not have one entry per design column.

    Examples
    --------
    Basic usage with condition comparison:

    >>> import delnx as dx
    >>> fit = dx.tl.nb_fit(adata, condition_key="treatment")
    >>> results = dx.tl.nb_test(adata, fit, contrast="treatment[T.drugA]")

    With reference level and covariates:

    >>> fit = dx.tl.nb_fit(adata, condition_key="treatment", reference="control",
    ...                    covariate_keys=["batch", "sex"])

    Formula-based design with interactions:

    >>> fit = dx.tl.nb_fit(adata, formula="~ treatment * batch")
    >>> results = dx.tl.nb_test(adata, fit, contrast="treatment[T.drugA]:batch[T.batch2]")

    Continuous covariate:

    >>> fit = dx.tl.nb_fit(adata, formula="~ age + sex")
    >>> results = dx.tl.nb_test(adata, fit, contrast="age")
    """
    from ._design import build_design

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Get count matrix
    X = _get_layer(adata, layer)
    n_samples, n_genes = X.shape

    # Compute or validate size factors
    if size_factors is None:
        sf = np.ones(n_samples)
    elif isinstance(size_factors, str):
        # compute_size_factors stores result in adata.obs
        compute_size_factors(adata, layer=layer, method=size_factors)
        sf = adata.obs["size_factors"].values.astype(float)
    else:
        sf = np.asarray(size_factors, dtype=float)
        if sf.shape != (n_samples,):
            raise ValueError(
                f"size_factors must have shape ({n_samples},), got {sf.shape}"
            )
    # Floor the size factors so the log offset stays finite.
    sf = np.maximum(sf, 1e-10)
    log_sf = np.log(sf)

    # Build design matrix (priority: design > formula > condition_key)
    if design is not None:
        design_matrix = np.asarray(design, dtype=float)
        if design_matrix.ndim != 2 or design_matrix.shape[0] != n_samples:
            raise ValueError(
                f"design must be a 2-D matrix with {n_samples} rows, "
                f"got shape {design_matrix.shape}"
            )
        if design_column_names is None:
            design_column_names = [f"coef_{i}" for i in range(design_matrix.shape[1])]
        elif len(design_column_names) != design_matrix.shape[1]:
            raise ValueError(
                f"design_column_names has {len(design_column_names)} entries but "
                f"design has {design_matrix.shape[1]} columns"
            )
    elif formula is not None or condition_key is not None:
        design_matrix, design_column_names = build_design(
            adata.obs,
            formula=formula,
            condition_key=condition_key,
            reference=reference,
            covariate_keys=covariate_keys,
        )
    else:
        # Intercept only
        design_matrix = np.ones((n_samples, 1))
        design_column_names = ["Intercept"]

    n_coef = design_matrix.shape[1]
    if n_coef < 1:
        raise ValueError("design matrix must have at least one column")

    # Convert to JAX arrays
    design_jax = jnp.array(design_matrix)
    offset_jax = jnp.array(log_sf)

    # The closed-form one-group solver is only valid for a true intercept
    # column; a single non-constant column must go through the general solver.
    is_intercept_only = n_coef == 1 and bool(np.all(np.asarray(design_matrix) == 1))

    # Pre-compute QR decomposition for beta initialization
    # (matching R's estimate_betas_roughly: OLS on log-normalized counts)
    if not is_intercept_only:
        Q_init, R_init = np.linalg.qr(design_matrix)
        Q_init = Q_init[:, :n_coef]
        R_init = R_init[:n_coef, :]

    # Auto-tune batch size to fit GPU memory.
    # Peak memory per gene in Newton-Raphson: ~n_samples * n_coef * 8 * 10 bytes
    # (Hessian, score, W, delta, mu: multiple intermediates alive at once).
    # Target: keep batch below ~4 GB to leave room for JAX overhead.
    if not is_intercept_only and n_coef > 10:
        mem_per_gene = n_samples * n_coef * 8 * 10
        max_batch = max(16, int(4e9 / mem_per_gene))
        if max_batch < batch_size:
            batch_size = max_batch

    n_batches = (n_genes + batch_size - 1) // batch_size
    if verbose:
        logger.info(f"Fitting {n_genes} genes with {n_coef} coefficient(s) (batch_size={batch_size})", verbose=True)

    # Initialize storage
    all_beta = np.zeros((n_genes, n_coef))
    dispersions = np.zeros(n_genes)
    deviances = np.zeros(n_genes)
    converged = np.zeros(n_genes, dtype=bool)
    all_mu = np.zeros((n_samples, n_genes))
    sf_jax = jnp.array(sf)

    for batch_idx in tqdm.tqdm(range(n_batches), disable=not verbose, desc="Fitting GLMs"):
        start = batch_idx * batch_size
        end = min(start + batch_size, n_genes)
        b = end - start  # batch width

        # Get batch data as dense JAX array (n_samples, b)
        X_batch_np = _to_dense(X[:, start:end]).astype(np.float64)
        X_batch = jnp.array(X_batch_np)

        # --- Step 1: Initial beta fit with moment-based dispersion ---
        # Moment dispersion around the size-factor-scaled gene mean; the same
        # starting point is used for both solvers.
        mu_init = sf_jax[:, None] * jnp.mean(X_batch / sf_jax[:, None], axis=0, keepdims=True)
        disp_init = estimate_dispersion_moments_batch(X_batch, mu_init)
        disp_init = jnp.maximum(disp_init, 1e-8)

        if is_intercept_only:
            # Batch fit intercept-only
            beta0_arr, dev_arr, conv_arr = fit_beta_one_group_batch(
                X_batch, sf_jax, disp_init, maxiter, 1e-8,
            )
            # Compute mu from intercept
            mu_batch = sf_jax[:, None] * jnp.exp(beta0_arr[None, :])
        else:
            # OLS initialization for all genes in batch
            log_norm = np.log(X_batch_np / sf[:, None] + 1.0)  # (n_samples, b)
            init_betas = np.linalg.solve(R_init, Q_init.T @ log_norm)  # (n_coef, b)
            init_betas_jax = jnp.array(init_betas.T)  # (b, n_coef)
            # Batch Newton-Raphson fit
            beta_arr, dev_arr, conv_arr = fit_beta_newton_batch(
                X_batch, design_jax, offset_jax, disp_init, init_betas_jax,
                maxiter, 1e-8,
            )
            # Compute mu
            # beta_arr: (b, n_coef), design_jax: (n_samples, n_coef)
            eta_batch = design_jax @ beta_arr.T + offset_jax[:, None]  # (n_samples, b)
            mu_batch = jnp.exp(jnp.clip(eta_batch, -50, 50))

        # --- Step 2: Dispersion MLE ---
        if overdispersion:
            disp_arr, _ = estimate_dispersion_mle_batch(
                X_batch, mu_batch, design_jax, disp_init,
                do_cox_reid_adjustment, 50, 1e-6,
            )
        else:
            disp_arr = jnp.full(b, 1e-10)

        # --- Step 3: Refit beta with final dispersion ---
        if is_intercept_only:
            beta0_arr, dev_arr, conv_arr = fit_beta_one_group_batch(
                X_batch, sf_jax, disp_arr, maxiter, 1e-8,
            )
            mu_batch = sf_jax[:, None] * jnp.exp(beta0_arr[None, :])
            all_beta[start:end, 0] = np.asarray(beta0_arr)
        else:
            # Warm-start from the step-1 coefficients.
            beta_arr, dev_arr, conv_arr = fit_beta_newton_batch(
                X_batch, design_jax, offset_jax, disp_arr, beta_arr,
                maxiter, 1e-8,
            )
            eta_batch = design_jax @ beta_arr.T + offset_jax[:, None]
            mu_batch = jnp.exp(jnp.clip(eta_batch, -50, 50))
            all_beta[start:end] = np.asarray(beta_arr)

        # Store results
        dispersions[start:end] = np.asarray(disp_arr)
        deviances[start:end] = np.asarray(dev_arr)
        converged[start:end] = np.asarray(conv_arr)
        all_mu[:, start:end] = np.asarray(mu_batch)

    # Apply quasi-likelihood shrinkage if requested
    ql_dispersions = None
    df0_prior = 0.0
    dispersion_trend = None
    if overdispersion_shrinkage and overdispersion:
        if verbose:
            logger.info("Applying quasi-likelihood shrinkage", verbose=True)

        # Compute mean expression
        mean_expression = np.mean(all_mu, axis=0)

        # Fit dispersion trend
        dispersion_trend = fit_dispersion_trend(
            dispersions, mean_expression, method="local_median"
        )

        # Recompute deviances using trend dispersion (batched)
        disp_trend_jax = jnp.array(dispersion_trend)
        for batch_start in range(0, n_genes, batch_size):
            batch_end = min(batch_start + batch_size, n_genes)
            counts_batch = jnp.array(_to_dense(X[:, batch_start:batch_end]).astype(np.float64))
            mu_batch = jnp.array(all_mu[:, batch_start:batch_end])
            disp_batch = disp_trend_jax[batch_start:batch_end]
            deviances[batch_start:batch_end] = np.asarray(
                compute_gp_deviance_batch(counts_batch, mu_batch, disp_batch)
            )

        # Transform to QL scale
        ql_dispersions = compute_ql_dispersions(
            dispersions, mean_expression, dispersion_trend
        )

        # Apply empirical Bayes shrinkage
        df_residual = n_samples - n_coef
        ql_dispersions, df0_prior, _ = shrink_ql_dispersions(
            ql_dispersions, df_residual
        )

    return NBFitResult(
        beta=all_beta,
        overdispersions=dispersions,
        mu=all_mu,
        size_factors=sf,
        deviances=deviances,
        design_matrix=design_matrix,
        design_column_names=design_column_names,
        feature_names=adata.var_names,
        layer=layer,
        condition_key=condition_key,
        ql_dispersions=ql_dispersions,
        df0_prior=df0_prior,
        dispersion_trend=dispersion_trend,
        converged=converged,
    )
# =============================================================================
# Differential expression testing
# =============================================================================
def _wald_test(
    fit: NBFitResult,
    adata: AnnData,
    contrast_vec: np.ndarray,
    multitest_method: str = "fdr_bh",
    batch_size: int = 512,
    lfc_threshold: float = 0.0,
) -> pd.DataFrame:
    """Wald test for a custom contrast vector.

    Tests H0: c' beta = 0 for each gene using the Wald statistic
    (c' beta)^2 / (c' V c), where V = (X'WX)^-1 is the coefficient covariance.

    Parameters
    ----------
    fit : NBFitResult
        Fitted model; ``beta``, ``mu``, ``design_matrix``, the dispersions and
        the quasi-likelihood fields are read from it.
    adata : AnnData
        Not used by this function; kept so the signature mirrors
        :func:`nb_test`.
    contrast_vec : np.ndarray
        Contrast vector ``c`` with one entry per model coefficient.
    multitest_method : str, default="fdr_bh"
        Method passed to ``statsmodels`` ``multipletests``.
    batch_size : int, default=512
        Number of genes whose fitted means are sliced at a time.
    lfc_threshold : float, default=0.0
        If positive, rows with ``|log2fc|`` below it are dropped after the
        p-value adjustment.

    Returns
    -------
    pd.DataFrame
        Columns ``feature``, ``log2fc``, ``coef``, ``stat``, ``pval`` and
        ``padj``, sorted by ``pval``. Genes whose ``X'WX`` is singular get a
        NaN statistic and p-value; their ``padj`` is left at 1.

    Raises
    ------
    ValueError
        If ``contrast_vec`` does not have one entry per coefficient or
        ``batch_size`` is not positive.
    """
    from scipy import stats

    n_genes, n_coef = fit.beta.shape
    n_samples = fit.design_matrix.shape[0]
    df_full = n_samples - n_coef

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    c = np.asarray(contrast_vec, dtype=float)
    if c.shape != (n_coef,):
        raise ValueError(
            f"contrast vector must have shape ({n_coef},), got {c.shape}"
        )
    X = fit.design_matrix

    # c' beta for all genes: (n_genes,)
    coefs = fit.beta @ c
    # Log2 fold change, capped at +/-10
    log2fc = np.clip(coefs / np.log(2), -10.0, 10.0)

    # Use dispersion trend if available, else MLE
    disp = np.asarray(
        fit.dispersion_trend if fit.dispersion_trend is not None else fit.overdispersions
    )

    # Compute Wald variance per gene: c' (X'WX)^-1 c
    # W_ii = mu_i / (1 + disp * mu_i): expected Fisher weights
    wald_var = np.zeros(n_genes)
    for start in range(0, n_genes, batch_size):
        end = min(start + batch_size, n_genes)
        mu_batch = fit.mu[:, start:end]  # (n_samples, batch)
        # Weights for the whole batch in one vectorized step.
        w_batch = mu_batch / (1.0 + disp[start:end][None, :] * mu_batch)
        for j in range(end - start):
            XtWX = X.T @ (X * w_batch[:, j, None])
            try:
                # Solve against c directly; the full inverse is never needed.
                wald_var[start + j] = c @ np.linalg.solve(XtWX, c)
            except np.linalg.LinAlgError:
                wald_var[start + j] = np.nan

    # F-statistic
    f_stats = coefs**2 / np.maximum(wald_var, 1e-300)
    if fit.ql_dispersions is not None:
        f_stats = f_stats / np.maximum(fit.ql_dispersions, 1e-10)
        df_denom = df_full + fit.df0_prior
        pvals = stats.f.sf(f_stats, 1, df_denom)
    else:
        pvals = stats.chi2.sf(f_stats, df=1)

    # Multiple testing correction (non-finite p-values keep padj = 1)
    valid_pvals = np.isfinite(pvals)
    padj = np.ones_like(pvals)
    if valid_pvals.sum() > 0:
        padj[valid_pvals] = sm.stats.multipletests(
            pvals[valid_pvals], method=multitest_method
        )[1]

    results = pd.DataFrame({
        "feature": fit.feature_names,
        "log2fc": log2fc,
        "coef": coefs,
        "stat": f_stats,
        "pval": pvals,
        "padj": padj,
    })
    if lfc_threshold > 0:
        results = results[np.abs(results["log2fc"]) >= lfc_threshold]
    return results.sort_values("pval").reset_index(drop=True)
[docs]
def nb_test(
    adata: AnnData,
    fit: NBFitResult,
    contrast: str | int | None = None,
    reduced_design: np.ndarray | None = None,
    multitest_method: str = "fdr_bh",
    lfc_threshold: float = 0.0,
    batch_size: int = 512,
) -> pd.DataFrame:
    """Test for differential expression using quasi-likelihood F-test.

    The reduced model is refitted per gene and its deviance is compared with
    the full-model deviance stored on ``fit``. If ``contrast`` is resolved to
    a contrast vector by ``parse_contrast_vector``, a Wald test is run instead
    and ``reduced_design`` is ignored.

    Parameters
    ----------
    adata : AnnData
        AnnData object (same one passed to ``nb_fit()``).
    fit : NBFitResult
        Fitted model from ``nb_fit()``.
    contrast : str | int | None, default=None
        Coefficient to test. Supports shorthand:

        - Level name: ``"drugA"`` (resolved via ``condition_key``).
        - Bracket shorthand: ``"treatment[drugA]"`` (for multi-covariate models).
        - Full patsy name: ``"treatment[T.drugA]"`` (always works).
        - Integer index.
        - None: Test last coefficient.
    reduced_design : np.ndarray | None, default=None
        Reduced design matrix for likelihood ratio test, with n_samples rows
        and fewer columns than the full design. If None, automatically created
        by dropping the tested coefficient.
    multitest_method : str, default="fdr_bh"
        Method for multiple testing correction.
    lfc_threshold : float, default=0.0
        Threshold for log2 fold change filtering.
    batch_size : int, default=512
        Number of genes to process in each batch. Controls GPU memory usage
        for the reduced model refit.

    Returns
    -------
    pd.DataFrame
        Results with columns:

        - ``feature``: Gene/feature names
        - ``log2fc``: Log2 fold change
        - ``coef``: Model coefficient
        - ``stat``: F-statistic (with QL dispersions) or chi-squared
          likelihood-ratio statistic (without)
        - ``pval``: Raw p-value
        - ``padj``: Adjusted p-value

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, if the counts in ``adata`` do not
        match the shape of the fit, or if ``reduced_design`` does not have
        n_samples rows and fewer columns than the full design.

    Examples
    --------
    Test for treatment effect:

    >>> fit = dx.tl.nb_fit(adata, condition_key="treatment")
    >>> results = dx.tl.nb_test(adata, fit)

    Test specific contrast by name:

    >>> results = dx.tl.nb_test(adata, fit, contrast="treatmentB")
    """
    from scipy import stats

    from ._design import parse_contrast_vector, resolve_contrast

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_genes, n_coef = fit.beta.shape
    n_samples = fit.design_matrix.shape[0]

    # Determine which coefficient to test
    contrast_vec = parse_contrast_vector(contrast, fit.design_column_names, condition_key=fit.condition_key)
    if contrast_vec is not None:
        return _wald_test(fit, adata, contrast_vec, multitest_method, batch_size, lfc_threshold)
    test_idx = resolve_contrast(contrast, fit.design_column_names, condition_key=fit.condition_key)

    # Get coefficients for tested term
    coefs = fit.beta[:, test_idx]
    # Log2 fold change (coefficient is on log scale), capped at +/-10
    log2fc = np.clip(coefs / np.log(2), -10.0, 10.0)

    # Build reduced design by dropping tested column
    if reduced_design is None:
        # np.delete also handles a negative index such as -1 correctly.
        reduced_design = np.delete(np.asarray(fit.design_matrix), test_idx, axis=1)
    reduced_design = np.asarray(reduced_design, dtype=float)
    if reduced_design.ndim != 2 or reduced_design.shape[0] != n_samples:
        raise ValueError(
            f"reduced_design must be a 2-D matrix with {n_samples} rows, "
            f"got shape {reduced_design.shape}"
        )
    n_coef_reduced = reduced_design.shape[1]

    df_full = n_samples - n_coef
    df_test = n_coef - n_coef_reduced  # typically 1
    if df_test < 1:
        raise ValueError(
            f"reduced_design must have fewer columns than the full design "
            f"({n_coef_reduced} vs {n_coef})"
        )

    # Likelihood ratio test: refit reduced model and compare deviances.
    reduced_jax = jnp.array(reduced_design)
    offset_jax = jnp.array(np.log(np.maximum(fit.size_factors, 1e-10)))

    # The one-group solver is only valid when the single remaining column is
    # a true intercept (all ones).
    reduced_is_intercept = n_coef_reduced == 1 and bool(np.all(reduced_design == 1.0))

    # Use dispersion trend if available (matching R), else MLE
    disp_vec = (
        jnp.array(fit.dispersion_trend)
        if fit.dispersion_trend is not None
        else jnp.array(fit.overdispersions)
    )

    # Get counts from adata (not stored on fit)
    X = _get_layer(adata, fit.layer)
    if tuple(X.shape) != (n_samples, n_genes):
        raise ValueError(
            f"adata counts have shape {tuple(X.shape)} but the fit expects "
            f"({n_samples}, {n_genes}); pass the AnnData used in nb_fit()"
        )
    sf_jax = jnp.array(fit.size_factors)

    # Auto-tune batch size for reduced model refit
    if n_coef_reduced > 10:
        mem_per_gene = n_samples * n_coef_reduced * 8 * 10
        max_batch = max(16, int(4e9 / mem_per_gene))
        if max_batch < batch_size:
            batch_size = max_batch

    # Refit reduced model in gene batches to control GPU memory
    deviances_reduced = np.zeros(n_genes)
    for start in range(0, n_genes, batch_size):
        end = min(start + batch_size, n_genes)
        counts_batch = jnp.array(_to_dense(X[:, start:end]).astype(np.float64))
        disp_batch = disp_vec[start:end]
        if n_coef_reduced == 0:
            # No coefficients left: the null model is the offset alone, so
            # mu = exp(offset) for every gene and nothing needs fitting.
            mu_null = jnp.broadcast_to(jnp.exp(offset_jax)[:, None], counts_batch.shape)
            dev_batch = compute_gp_deviance_batch(counts_batch, mu_null, disp_batch)
        elif reduced_is_intercept:
            _, dev_batch, _ = fit_beta_one_group_batch(
                counts_batch, sf_jax, disp_batch, 100, 1e-8,
            )
        else:
            # Start from zeros, seeding the first coefficient with the
            # full-model first coefficient.
            # NOTE(review): this presumes column 0 is the intercept in both
            # designs; it only affects the starting point.
            init_batch = jnp.zeros((end - start, n_coef_reduced))
            init_batch = init_batch.at[:, 0].set(jnp.array(fit.beta[start:end, 0]))
            _, dev_batch, _ = fit_beta_newton_batch(
                counts_batch, reduced_jax, offset_jax, disp_batch, init_batch,
                100, 1e-8,
            )
        deviances_reduced[start:end] = np.asarray(dev_batch)

    # Test statistic from deviance difference (clamped at zero)
    dev_diff = np.maximum(deviances_reduced - fit.deviances, 0.0)
    if fit.ql_dispersions is not None:
        # Quasi-likelihood F-test
        f_stats = dev_diff / (df_test * np.maximum(fit.ql_dispersions, 1e-10))
        df_denom = df_full + fit.df0_prior
        pvals = stats.f.sf(f_stats, df_test, df_denom)
    else:
        # Without QL shrinkage, use chi-squared approximation: the deviance
        # difference itself is compared against chi2(df_test).
        f_stats = dev_diff
        pvals = stats.chi2.sf(f_stats, df=df_test)

    # Multiple testing correction (non-finite p-values keep padj = 1)
    valid_pvals = np.isfinite(pvals)
    padj = np.ones_like(pvals)
    if valid_pvals.sum() > 0:
        padj[valid_pvals] = sm.stats.multipletests(
            pvals[valid_pvals], method=multitest_method
        )[1]

    # Create results DataFrame
    results = pd.DataFrame({
        "feature": fit.feature_names,
        "log2fc": log2fc,
        "coef": coefs,
        "stat": f_stats,
        "pval": pvals,
        "padj": padj,
    })

    # Apply LFC threshold if specified
    if lfc_threshold > 0:
        results = results[np.abs(results["log2fc"]) >= lfc_threshold]

    # Sort by p-value
    results = results.sort_values("pval").reset_index(drop=True)
    return results
[docs]
def nb_de(
    adata: AnnData,
    condition_key: str | None = None,
    formula: str | None = None,
    design: np.ndarray | None = None,
    design_column_names: list[str] | None = None,
    reference: str | None = None,
    covariate_keys: list[str] | None = None,
    size_factors: str | np.ndarray | None = "normed_sum",
    layer: str | None = None,
    contrast: str | int | None = None,
    multitest_method: str = "fdr_bh",
    lfc_threshold: float = 0.0,
    overdispersion: bool = True,
    batch_size: int = 512,
    maxiter: int = 100,
    verbose: bool = True,
    overdispersion_shrinkage: bool = True,
    do_cox_reid_adjustment: bool = True,
) -> pd.DataFrame:
    """Fit a negative binomial model and test one contrast in a single call.

    Runs :func:`nb_fit` followed by :func:`nb_test` on the same ``adata``.
    To test several contrasts against one fit, call the two functions
    separately and reuse the fit.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing count data.
    condition_key : str | None, default=None
        Column in ``adata.obs`` for condition labels. Mutually exclusive
        with ``formula``.
    formula : str | None, default=None
        R-style formula for the design matrix (e.g., ``"~ treatment + batch"``).
        Mutually exclusive with ``condition_key``.
    design : np.ndarray | None, default=None
        Custom design matrix. Overrides ``condition_key`` and ``formula``.
    design_column_names : list[str] | None, default=None
        Column names for a custom ``design`` matrix.
    reference : str | None, default=None
        Reference level for the condition.
    covariate_keys : list[str] | None, default=None
        Columns in ``adata.obs`` to include as covariates.
    size_factors : str | np.ndarray | None, default="normed_sum"
        Size factors for normalization.
    layer : str | None, default=None
        Layer in ``adata.layers`` containing counts.
    contrast : str | int | None, default=None
        Contrast to test (passed to :func:`nb_test`).
    multitest_method : str, default="fdr_bh"
        Method for multiple testing correction.
    lfc_threshold : float, default=0.0
        Minimum absolute log2 fold change threshold.
    overdispersion : bool, default=True
        Whether to estimate overdispersion.
    batch_size : int, default=512
        Number of genes per batch; used for both the fit and the test.
    maxiter : int, default=100
        Maximum iterations for Newton-Raphson.
    verbose : bool, default=True
        Whether to show progress.
    overdispersion_shrinkage : bool, default=True
        Whether to apply quasi-likelihood shrinkage.
    do_cox_reid_adjustment : bool, default=True
        Whether to apply Cox-Reid adjustment.

    Returns
    -------
    pd.DataFrame
        DE results (same as :func:`nb_test`).

    Examples
    --------
    Simple condition comparison:

    >>> results = dx.tl.nb_de(adata, condition_key="treatment", reference="control")

    Formula-based with covariates:

    >>> results = dx.tl.nb_de(adata, formula="~ treatment + batch",
    ...                       contrast="treatment[T.drugA]")
    """
    # Options that shape the model fit, forwarded unchanged to nb_fit.
    fit_options = {
        "condition_key": condition_key,
        "formula": formula,
        "design": design,
        "design_column_names": design_column_names,
        "reference": reference,
        "covariate_keys": covariate_keys,
        "size_factors": size_factors,
        "layer": layer,
        "overdispersion": overdispersion,
        "batch_size": batch_size,
        "maxiter": maxiter,
        "verbose": verbose,
        "overdispersion_shrinkage": overdispersion_shrinkage,
        "do_cox_reid_adjustment": do_cox_reid_adjustment,
    }
    fitted_model = nb_fit(adata, **fit_options)

    # Options that shape the hypothesis test, forwarded unchanged to nb_test.
    test_options = {
        "contrast": contrast,
        "multitest_method": multitest_method,
        "lfc_threshold": lfc_threshold,
        "batch_size": batch_size,
    }
    return nb_test(adata, fitted_model, **test_options)