Negative binomial DE with quasi-likelihood dispersion shrinkage

delnx implements GPU-accelerated negative binomial GLMs for differential expression analysis, following the glmGamPoi approach with quasi-likelihood dispersion shrinkage. The two-step workflow (nb_fit, then nb_test) separates model fitting from hypothesis testing, so a single fit can be reused across multiple contrasts.

import delnx as dx
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt

# Load example pseudobulk data
adata = sc.read_h5ad("data/GLI3_KO_45d_pseudobulk.h5ad")
adata.obs["GLI3_KO"] = adata.obs["GLI3_KO"].astype(str)  # Ensure string type for design matrix

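# Replace X (normalized values) with raw counts, densifying the layer if it is sparse
counts = adata.layers["counts"]
counts = counts.toarray() if hasattr(counts, "toarray") else np.asarray(counts)
adata.X = np.round(counts).astype(np.float64)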

print(adata)
print(f"Count range: {adata.X.min():.0f} - {adata.X.max():.0f}")
AnnData object with n_obs × n_vars = 28 × 16199
    obs: 'psbulk_replicate', 'cell_type', 'organoid', 'GLI3_KO', 'psbulk_cells', 'psbulk_counts', 'size_factor'
    var: 'dispersion', 'dispersion_deseq', 'dispersion_mle', 'dispersion_edger', 'mean', 'mean_norm'
    uns: 'log1p'
    layers: 'counts', 'psbulk_props'
Count range: 0 - 123900
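The NB model expects non-negative integer counts, and the rounding step above would silently coerce normalized values into something that looks like counts. A quick check before rounding can catch that. This is a minimal sketch; `looks_like_counts` is a hypothetical helper, not part of delnx or scanpy.

```python
import numpy as np
from scipy import sparse


def looks_like_counts(x, tol=1e-8):
    """Return True if every stored value is non-negative and (near-)integer.

    Hypothetical helper; accepts dense arrays and scipy sparse matrices.
    """
    # For sparse input only the explicitly stored entries need checking (the rest are 0)
    values = x.data if sparse.issparse(x) else np.asarray(x)
    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        return True
    return bool(np.all(values >= 0) and np.all(np.abs(values - np.round(values)) <= tol))


# Made-up raw counts and a log-normalized version of them
raw = np.array([[0, 3, 12], [5, 0, 1]])
lognorm = np.log1p(raw / raw.sum(axis=1, keepdims=True) * 1e4)

print(looks_like_counts(raw))                     # True
print(looks_like_counts(sparse.csr_matrix(raw)))  # True
print(looks_like_counts(lognorm))                 # False
```

In this notebook the call would be `looks_like_counts(adata.layers["counts"])`, run before the rounding step.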

Fit the model

nb_fit handles everything in one call: size factor estimation, gene-wise dispersion MLE using the Cox-Reid adjusted profile likelihood, quasi-likelihood shrinkage, and coefficient fitting via IRLS.
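In the glmGamPoi parameterization, the dispersion (overdispersion) α enters the NB variance as Var(Y) = μ + αμ², so α = 0 reduces to a Poisson model. Assuming delnx uses the same parameterization (plausible given its glmGamPoi lineage, but not stated here), the sketch below simulates NB counts with NumPy and recovers α with a simple method-of-moments estimator. It only illustrates what the parameter means; nb_fit itself estimates α by Cox-Reid adjusted maximum likelihood, not by moments, and `simulate_nb` and `mom_dispersion` are illustrative helpers.

```python
import numpy as np


def simulate_nb(mu, alpha, size, rng):
    """Draw NB counts with mean mu and variance mu + alpha * mu**2."""
    n = 1.0 / alpha   # NumPy's "number of successes" parameter
    p = n / (n + mu)  # makes the NumPy mean n * (1 - p) / p equal to mu
    return rng.negative_binomial(n, p, size=size)


def mom_dispersion(y):
    """Method-of-moments dispersion: solve s^2 = m + alpha * m^2 for alpha."""
    m = y.mean()
    s2 = y.var(ddof=1)
    return (s2 - m) / m**2


rng = np.random.default_rng(0)
for mu in (5.0, 50.0, 500.0):
    y = simulate_nb(mu, alpha=0.2, size=200_000, rng=rng)
    print(f"mu={mu:>5}: mean={y.mean():.1f}, var={y.var():.1f}, alpha_hat={mom_dispersion(y):.3f}")
```

The same α produces very different variances at different means, which is why dispersion, not raw variance, is the quantity compared across genes.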

# Fit negative binomial GLMs with quasi-likelihood dispersion shrinkage
fit = dx.tl.nb_fit(adata, condition_key="GLI3_KO", reference="True")

print(f"Coefficients: {fit.design_column_names}")
print(f"Genes fitted: {len(fit.overdispersions)}")
print(f"Dispersion range: {fit.overdispersions.min():.4f} - {fit.overdispersions.max():.2f}")
INFO     Fitting 16199 genes with 2 coefficient(s)
INFO     Applying quasi-likelihood shrinkage
Coefficients: ['Intercept', 'GLI3_KO[T.False]']
Genes fitted: 16199
Dispersion range: nan - nan
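The dispersion range prints as `nan - nan` because NumPy's `.min()` and `.max()` propagate NaN: a single gene with a NaN dispersion makes the whole reduction NaN. These are presumably genes that could not be fitted, such as genes with all-zero counts; the corresponding rows of the DE table further down also have NaN statistics. `np.nanmin` and `np.nanmax` skip NaN values. A small self-contained illustration, using a made-up array as a stand-in for `fit.overdispersions`:

```python
import numpy as np

# Made-up stand-in for fit.overdispersions, with NaN for genes that were not fitted
disp = np.array([0.05, 0.12, np.nan, 1.8, np.nan, 0.003])

print(f"naive range:     {disp.min():.4f} - {disp.max():.2f}")            # nan - nan
print(f"NaN-aware range: {np.nanmin(disp):.4f} - {np.nanmax(disp):.2f}")  # 0.0030 - 1.80
print(f"genes with NaN dispersion: {np.isnan(disp).sum()} of {disp.size}")  # 2 of 6
```

Applied to the notebook, the same substitution would be `np.nanmin(fit.overdispersions)` and `np.nanmax(fit.overdispersions)`.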
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: dispersion distribution
# Keep positive, non-NaN dispersions (NaN > 0 evaluates to False) so log10 is defined
valid_disp = fit.overdispersions[fit.overdispersions > 0]
axes[0].hist(np.log10(valid_disp), bins=50, color="steelblue", alpha=0.7, edgecolor="white")
axes[0].set_xlabel("log10(dispersion)")
axes[0].set_ylabel("Number of genes")
axes[0].set_title("Dispersion distribution")
axes[0].axvline(np.log10(np.median(valid_disp)), color="red", linestyle="--", label=f"median={np.median(valid_disp):.3f}")
axes[0].legend()

# Right: mean-dispersion relationship
# Per-gene mean of the raw counts in X; np.asarray(...).ravel() handles dense or sparse X
gene_means = np.asarray(adata.X.mean(axis=0)).ravel()
mask = (gene_means > 0) & (fit.overdispersions > 0)
axes[1].scatter(np.log10(gene_means[mask]), np.log10(fit.overdispersions[mask]), s=5, alpha=0.3)
axes[1].set_xlabel("log10(mean raw count)")
axes[1].set_ylabel("log10(dispersion)")
axes[1].set_title("Mean-dispersion relationship")

plt.tight_layout()
plt.show()
plt.close()
[Figure: left, "Dispersion distribution" (histogram of log10 dispersion with the median marked); right, "Mean-dispersion relationship" (log10 mean vs log10 dispersion per gene)]

The left panel shows the distribution of MLE dispersions across genes on a log10 scale (genes with NaN or non-positive dispersion are excluded). Most genes cluster around the median, with a long right tail of highly overdispersed genes.

The right panel shows the characteristic mean-dispersion relationship: lowly expressed genes tend to have higher dispersion, while highly expressed genes converge toward lower dispersion values.
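To summarize the trend numerically instead of reading it off the scatter plot, one option is the median log-dispersion within equal-width bins of log-mean. This is a minimal sketch with a hypothetical helper, `binned_median_trend`, run on synthetic data that follows the pattern described above; in the notebook you would pass `np.log10(gene_means[mask])` and `np.log10(fit.overdispersions[mask])` instead.

```python
import numpy as np


def binned_median_trend(log_mean, log_disp, n_bins=10):
    """Median log-dispersion per equal-width bin of log-mean (hypothetical helper)."""
    log_mean = np.asarray(log_mean, dtype=float)
    log_disp = np.asarray(log_disp, dtype=float)
    edges = np.linspace(log_mean.min(), log_mean.max(), n_bins + 1)
    # Digitize against interior edges only, so every value lands in bins 0..n_bins-1
    idx = np.digitize(log_mean, edges[1:-1])
    centers = (edges[:-1] + edges[1:]) / 2
    medians = np.array([
        np.median(log_disp[idx == b]) if np.any(idx == b) else np.nan
        for b in range(n_bins)
    ])
    return centers, medians


# Synthetic data with the described pattern: dispersion falls as the mean rises
rng = np.random.default_rng(0)
log_mean = rng.uniform(-1, 4, size=5000)
log_disp = 0.5 - 0.4 * log_mean + rng.normal(0, 0.3, size=5000)

centers, medians = binned_median_trend(log_mean, log_disp, n_bins=8)
for c, m in zip(centers, medians):
    print(f"log10 mean ~ {c:5.2f}: median log10 dispersion {m:6.2f}")
```

Medians are used rather than means so that the long tail of highly overdispersed genes does not pull the trend upward.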

Test for differential expression

Now we test for DE using the fitted model. nb_test performs a quasi-likelihood F-test, which accounts for uncertainty in the dispersion estimates instead of treating them as known.
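For intuition, the sketch below follows the glmGamPoi/edgeR-style quasi-likelihood F-test. Whether delnx computes its `stat` column in exactly this way is an assumption. The statistic is the deviance drop from removing the tested coefficient, divided by its degrees of freedom and by the gene's (shrunken) quasi-likelihood dispersion. It is referred to an F distribution whose denominator degrees of freedom add the prior degrees of freedom from shrinkage to the residual ones. The counts, `alpha`, `ql_disp` and `df_prior` are made-up values. With all size factors equal to 1, the group means are the fitted values of the two-group and intercept-only NB models, so no IRLS is needed here.

```python
import numpy as np
from scipy import stats


def nb_deviance(y, mu, alpha):
    """Total NB deviance for counts y, fitted means mu, variance mu + alpha * mu**2."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    r = 1.0 / alpha
    # The y * log(y / mu) term is defined as 0 when y == 0
    with np.errstate(divide="ignore", invalid="ignore"):
        term1 = np.where(y > 0, y * np.log(y / mu), 0.0)
    term2 = (y + r) * np.log((y + r) / (mu + r))
    return 2.0 * np.sum(term1 - term2)


def ql_f_test(y, in_group, alpha, ql_disp, df_prior):
    """Two-group QL F-test for one coefficient, assuming equal size factors."""
    y = np.asarray(y, dtype=float)
    in_group = np.asarray(in_group, dtype=bool)
    # Full model (intercept + group): the fitted value is each group's mean
    mu_full = np.where(in_group, y[in_group].mean(), y[~in_group].mean())
    # Reduced model (intercept only): the fitted value is the overall mean
    mu_red = np.full_like(y, y.mean())
    df_test = 1
    df_resid = y.size - 2
    dev_diff = nb_deviance(y, mu_red, alpha) - nb_deviance(y, mu_full, alpha)
    f_stat = dev_diff / df_test / ql_disp
    pval = stats.f.sf(f_stat, df_test, df_resid + df_prior)
    return f_stat, pval


# Made-up counts for one gene: four reference samples, four samples of the other group
y = np.array([10, 12, 9, 11, 30, 35, 28, 33])
in_group = np.array([False] * 4 + [True] * 4)
f_stat, pval = ql_f_test(y, in_group, alpha=0.05, ql_disp=1.0, df_prior=10)
print(f"F = {f_stat:.2f}, p = {pval:.2e}")
```

A larger quasi-likelihood dispersion shrinks the F statistic proportionally, and larger `df_prior` (stronger shrinkage) makes the reference F distribution less heavy-tailed.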

# Test for differential expression
de_results = dx.tl.nb_test(adata, fit, contrast="GLI3_KO[T.False]")

print(de_results)
      feature    log2fc      coef       stat          pval      padj
0        EMX1  6.761568  4.686762  83.710401  4.326881e-10  0.000007
1      ZNF429 -5.335782 -3.698482  75.925150  1.246159e-09  0.000010
2       SFTA3 -8.920224 -6.183028  60.244111  1.368526e-08  0.000074
3       CXXC4 -0.873021 -0.605132  54.949587  3.389955e-08  0.000106
4        ZIC5  1.904016  1.319763  54.774321  3.496747e-08  0.000106
...       ...       ...       ...        ...           ...       ...
16194    GJB3       NaN       NaN        NaN           NaN  1.000000
16195   CCL16       NaN       NaN        NaN           NaN  1.000000
16196    CCR1       NaN       NaN        NaN           NaN  1.000000
16197   PLET1       NaN       NaN        NaN           NaN  1.000000
16198    ORM2       NaN       NaN        NaN           NaN  1.000000

[16199 rows x 6 columns]
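Two columns carry the effect size. Judging from the rows shown, `coef` is the coefficient on the natural-log scale and `log2fc` is the same quantity in log2 units, i.e. log2fc = coef / ln 2 (for EMX1, 4.686762 / 0.693147 ≈ 6.7616). Because `reference="True"` makes the KO samples the baseline, the `GLI3_KO[T.False]` coefficient compares GLI3_KO == "False" (WT) against KO, so a positive log2fc means higher expression in WT. The sketch below checks the conversion on the first three printed rows and then, for a toy two-group design with all size factors equal to 1 and made-up counts, shows that under this treatment coding the fitted coefficients reduce to log group means.

```python
import numpy as np

# First three rows of the printed DE table: (feature, log2fc, coef)
rows = [("EMX1", 6.761568, 4.686762),
        ("ZNF429", -5.335782, -3.698482),
        ("SFTA3", -8.920224, -6.183028)]
for feature, log2fc_tab, coef_tab in rows:
    # Natural-log coefficient converted to log2 units
    print(f"{feature}: coef / ln 2 = {coef_tab / np.log(2):.6f} (table log2fc {log2fc_tab:.6f})")

# Toy two-group layout, treatment coding with reference level "True" (KO).
# With all size factors equal to 1, the NB GLM fitted values are the group means, so:
#   Intercept        = log(mean of the reference group)
#   GLI3_KO[T.False] = log(mean of the "False" group / mean of the reference group)
labels = np.array(["True", "True", "True", "False", "False", "False"])
counts = np.array([4.0, 6.0, 5.0, 40.0, 50.0, 45.0])  # made-up counts for one gene
design = np.column_stack([np.ones(labels.size), (labels == "False").astype(float)])
print("design matrix:\n", design)

mean_ref = counts[labels == "True"].mean()
mean_alt = counts[labels == "False"].mean()
intercept = np.log(mean_ref)
coef = np.log(mean_alt / mean_ref)
print(f"Intercept = {intercept:.4f}, coef = {coef:.4f}, log2fc = {coef / np.log(2):.4f}")
```

Setting `reference="False"` instead would make WT the baseline and flip the sign of every fold change.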
# Volcano plot
fig, ax = plt.subplots(figsize=(8, 6))

sig = de_results["padj"] < 0.05
colors = np.where(sig, "tab:red", "tab:grey")

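# Genes with NaN p-values (not fitted) are not drawn
ax.scatter(de_results["log2fc"], -np.log10(de_results["pval"]), c=colors, s=10, alpha=0.5)
# Dashed line marks nominal p = 0.05; point colors use the adjusted padj < 0.05 cutoff
ax.axhline(-np.log10(0.05), color="black", linestyle="--", linewidth=0.5, alpha=0.5)
ax.set_xlabel("log2 fold change (WT vs GLI3 KO)")
ax.set_ylabel("-log10(p-value)")
ax.set_title(f"WT vs GLI3 KO (reference: KO): {sig.sum()} significant genes (padj < 0.05)")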

plt.tight_layout()
plt.show()
plt.close()
[Figure: volcano plot of log2 fold change vs -log10(p-value), genes with padj < 0.05 in red]
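Genes are colored by `padj < 0.05`. The adjustment method is not stated here. Assuming it is the Benjamini-Hochberg false discovery rate (the usual choice for DE tables, and consistent with the first rows shown, e.g. 4.326881e-10 × 16199 ≈ 7.0e-06 for EMX1), a minimal NumPy version looks like this. `bh_adjust` is a hypothetical helper. NaN p-values are mapped to 1.0 to mirror the table above, and whether delnx counts NaN genes in the number of tests is not known.

```python
import numpy as np


def bh_adjust(pvals):
    """Benjamini-Hochberg adjusted p-values; NaN inputs are returned as 1.0."""
    p = np.asarray(pvals, dtype=float)
    out = np.ones_like(p)
    ok = ~np.isnan(p)
    pv = p[ok]
    m = pv.size
    if m == 0:
        return out
    order = np.argsort(pv)
    # p_(i) * m / i for the sorted p-values
    scaled = pv[order] * m / np.arange(1, m + 1)
    # Step-up: running minimum from the largest p-value downward, capped at 1
    scaled = np.minimum(np.minimum.accumulate(scaled[::-1])[::-1], 1.0)
    adj = np.empty(m)
    adj[order] = scaled
    out[ok] = adj
    return out


pvals = np.array([0.01, 0.04, 0.03, np.nan, 0.2])
padj = bh_adjust(pvals)
print(padj)
print("significant at padj < 0.05:", int((padj < 0.05).sum()))
```

The step-up minimum is what makes adjacent genes share a padj value, as with CXXC4 and ZIC5 in the table.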