delnx.models.NegativeBinomialRegression

class delnx.models.NegativeBinomialRegression(maxiter=100, tol=1e-06, optimizer='BFGS', skip_stats=False, dispersion=None, dispersion_range=(1e-08, 100.0))[source]

Negative Binomial regression in JAX.

This class implements Negative Binomial regression for modeling count data, particularly RNA-seq data, with support for offsets to incorporate size factors or other normalization terms. The model uses a log link function and allows for overdispersion in count data.
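To make "overdispersion" concrete, here is a minimal sketch of the variance function under the common NB2 parameterization, where the variance is mu + dispersion * mu**2. Whether delnx uses exactly this parameterization is an assumption, and `nb_variance` is a hypothetical helper, not part of the library.

```python
def nb_variance(mu, dispersion):
    """Variance of a negative binomial count under the (assumed) NB2 form.

    With dispersion == 0 this reduces to the Poisson variance (equal to mu).
    """
    return mu + dispersion * mu ** 2


mu = 20.0
poisson_var = nb_variance(mu, dispersion=0.0)  # variance equals the mean
nb_var = nb_variance(mu, dispersion=0.1)       # extra quadratic term inflates it
```

Under this form, a larger dispersion means more variance than a Poisson model with the same mean would allow.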

Parameters:
  • maxiter (int (default: 100)) – Maximum number of iterations for optimization algorithms.

  • tol (float (default: 1e-06)) – Convergence tolerance for optimization algorithms.

  • optimizer (str (default: 'BFGS')) – Optimization method to use. Options are 'BFGS' or 'IRLS'.

  • skip_stats (bool (default: False)) – If True, skip computing Wald test statistics; se, stat, and pval are then returned as None.

  • dispersion (float | None (default: None)) – Fixed dispersion parameter. If None, dispersion is estimated from the data.

  • dispersion_range (tuple[float, float] (default: (1e-08, 100.0))) – Range for the dispersion parameter. Used to constrain the estimated dispersion to avoid numerical issues.
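One plausible reading of dispersion_range is that an estimated dispersion is clipped into the interval. The sketch below shows that behaviour. It is an assumption about the mechanism (the library may instead use the range as optimizer bounds), and `clamp_dispersion` is a hypothetical name.

```python
def clamp_dispersion(estimate, dispersion_range=(1e-08, 100.0)):
    """Clip a dispersion estimate into [low, high] to keep later math finite."""
    low, high = dispersion_range
    return min(max(estimate, low), high)


tiny = clamp_dispersion(0.0)     # raised to the lower bound
huge = clamp_dispersion(250.0)   # lowered to the upper bound
typical = clamp_dispersion(0.3)  # already inside the range, returned unchanged
```

A strictly positive lower bound matters because a dispersion of exactly zero would make terms such as 1 / dispersion undefined.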

Examples

>>> import jax.numpy as jnp
>>> from delnx.models import NegativeBinomialRegression
>>> X = jnp.array([[1.0, 0.0], [1.0, 1.0]])  # Design matrix with intercept
>>> y = jnp.array([10.0, 20.0])  # Count data
>>> size_factors = jnp.array([0.8, 1.2])  # Size factors from normalization
>>> offset = jnp.log(size_factors)  # Log transform for offset
>>> model = NegativeBinomialRegression(optimizer="IRLS")
>>> result = model.fit(X, y, offset=offset)
>>> print(f"Coefficients: {result['coef']}")

Methods table

fit(X, y[, offset, test_idx])

Fit negative binomial regression model with optional offset.

get_llf(X, y, params[, offset, dispersion])

Get log-likelihood at fitted parameters with offset.

predict(X, params[, offset])

Predict count response variable using fitted model.

Attributes

NegativeBinomialRegression.dispersion: float | None = None
NegativeBinomialRegression.dispersion_range: tuple[float, float] = (1e-08, 100.0)
NegativeBinomialRegression.maxiter: int = 100
NegativeBinomialRegression.optimizer: str = 'BFGS'
NegativeBinomialRegression.skip_stats: bool = False
NegativeBinomialRegression.tol: float = 1e-06

Methods

NegativeBinomialRegression.fit(X, y, offset=None, test_idx=-1)[source]

Fit negative binomial regression model with optional offset.

This method fits a Negative Binomial regression model to count data, with support for including offset terms (typically log size factors) to account for normalization. The method also handles dispersion estimation if not provided during initialization.

Parameters:
  • X (Array) – Design matrix of shape (n_samples, n_features).

  • y (Array) – Count response vector of shape (n_samples,).

  • offset (Array | None (default: None)) – Offset term (log scale) to include in the model. Typically log(size_factors) for RNA-seq data. If provided, overrides the offset set during class initialization.

  • test_idx (int (default: -1)) – Index of the parameter to test. If -1, tests the last parameter.

Return type:

dict

Returns:

Dictionary containing:

  • coef: Parameter estimates

  • llf: Log-likelihood at fitted parameters

  • se: Standard errors (None if skip_stats=True)

  • stat: Test statistics (None if skip_stats=True)

  • pval: P-values (None if skip_stats=True)

  • dispersion: Estimated or provided dispersion parameter
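The relationship between coef, se, stat, and pval can be sketched with a standard two-sided Wald z-test. This assumes stat is coef / se for the tested parameter and that the p-value comes from a standard normal reference distribution; the library may report a different form (for example a squared statistic). The numbers below are made up for illustration, and `wald_test` is a hypothetical helper.

```python
import math


def wald_test(coef, se):
    """Two-sided Wald z-test for a single coefficient.

    Returns (z, pval), where pval = 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2)).
    """
    z = coef / se
    pval = math.erfc(abs(z) / math.sqrt(2.0))
    return z, pval


# Hypothetical fitted values; test_idx=-1 selects the last parameter.
coef = [2.30, 0.69]
se = [0.30, 0.35]
test_idx = -1
z, pval = wald_test(coef[test_idx], se[test_idx])
```

With skip_stats=True the se, stat, and pval entries are None, so code that reads them should check for that case first.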

NegativeBinomialRegression.get_llf(X, y, params, offset=None, dispersion=1.0)[source]

Get log-likelihood at fitted parameters with offset.

Return type:

float
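As a rough guide to what such a log-likelihood computes, the sketch below evaluates the NB2 log-likelihood with a log link and an optional offset in plain Python. It assumes the size parameter is r = 1 / dispersion; the library's internal formula is not shown in this page, and `nb_loglik` is a hypothetical stand-in rather than the library's implementation.

```python
import math


def nb_loglik(X, y, params, offset=None, dispersion=1.0):
    """Sum of NB2 log-probabilities with mean mu_i = exp(x_i . params + offset_i)."""
    if offset is None:
        offset = [0.0] * len(y)
    r = 1.0 / dispersion  # assumed: size parameter is the inverse dispersion
    total = 0.0
    for x_row, y_i, off_i in zip(X, y, offset):
        eta = sum(x * b for x, b in zip(x_row, params)) + off_i
        mu = math.exp(eta)
        total += (
            math.lgamma(y_i + r) - math.lgamma(r) - math.lgamma(y_i + 1.0)
            + r * math.log(r / (r + mu))
            + y_i * math.log(mu / (r + mu))
        )
    return total


# With dispersion = 1 the distribution is geometric, which gives an easy check:
# mu = 1 and y = 2 has probability 0.5 * 0.5**2 = 0.125.
llf = nb_loglik([[1.0]], [2.0], [0.0], dispersion=1.0)
```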

NegativeBinomialRegression.predict(X, params, offset=None)[source]

Predict count response variable using fitted model.

Parameters:
  • X (Array) – Design matrix of shape (n_samples, n_features).

  • params (Array) – Fitted parameter estimates.

  • offset (Array | None (default: None)) – Offset term to include in the prediction. If provided, overrides the offset set during class initialization.

Return type:

Array

Returns:

jnp.ndarray – Predicted count response variable.
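Given the log link, a prediction on the count scale is presumably exp(X @ params + offset). The following plain-Python sketch mirrors that calculation under this assumption; `predict_counts` is a hypothetical name, and the coefficients are chosen for illustration only.

```python
import math


def predict_counts(X, params, offset=None):
    """Predicted mean counts under a log link: exp(x_i . params + offset_i)."""
    if offset is None:
        offset = [0.0] * len(X)
    return [
        math.exp(sum(x * b for x, b in zip(row, params)) + off)
        for row, off in zip(X, offset)
    ]


X = [[1.0, 0.0], [1.0, 1.0]]              # design matrix with intercept
params = [math.log(10.0), math.log(2.0)]  # illustrative coefficients
offset = [math.log(0.8), math.log(1.2)]   # log size factors
mu = predict_counts(X, params, offset)    # close to [8.0, 24.0]
```

Because the offset enters on the log scale, a size factor of 1.2 multiplies the predicted count by 1.2 instead of adding to it.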