Source code for delnx.models._models

"""Regression models in JAX."""

import warnings
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy import optimize

# Turn on 64-bit floats for every JAX computation in this process. This is a
# process-wide side effect of importing the module; if it cannot be applied,
# only a warning is emitted so that importing never fails.
try:
    jax.config.update("jax_enable_x64", True)
    # Confirm the flag actually took effect before relying on it.
    if not jax.config.jax_enable_x64:
        warnings.warn(
            "JAX x64 precision could not be enabled. "
            "This might lead to numerical instabilities.",
            stacklevel=2,
        )
except Exception as exc:  # noqa: BLE001
    warnings.warn(f"JAX configuration failed: {exc}", stacklevel=2)


@dataclass(frozen=True)
class Regression:
    """Base class for regression models.

    This is the abstract base class for all regression models in the package.
    It provides common functionality for fitting models, computing statistics,
    and handling offsets for normalization.

    Parameters
    ----------
    maxiter : int, default=100
        Maximum number of iterations for optimization algorithms.
    tol : float, default=1e-6
        Convergence tolerance for optimization algorithms.
    optimizer : str, default="BFGS"
        Optimization method to use. Options include "BFGS" and "IRLS"
        (Iteratively Reweighted Least Squares) for GLM-type models.
    skip_stats : bool, default=False
        Whether to skip calculating Wald test statistics (for faster computation).
    """

    maxiter: int = 100
    tol: float = 1e-6
    optimizer: str = "BFGS"
    skip_stats: bool = False

    def _fit_bfgs(self, neg_ll_fn: Callable, init_params: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Fit model using the BFGS optimizer.

        Parameters
        ----------
        neg_ll_fn : Callable
            Function that computes the negative log-likelihood.
        init_params : jnp.ndarray
            Initial parameter values.
        **kwargs
            Additional arguments passed to the optimizer.

        Returns
        -------
        jnp.ndarray
            Optimized parameters.
        """
        result = optimize.minimize(neg_ll_fn, init_params, method="BFGS", options={"maxiter": self.maxiter})
        return result.x

    def _fit_irls(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        weight_fn: Callable,
        working_resid_fn: Callable,
        init_params: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        **kwargs,
    ) -> jnp.ndarray:
        """Fit model using Iteratively Reweighted Least Squares algorithm.

        This implements the IRLS algorithm for generalized linear models
        with support for offset terms. For count models (e.g., Negative
        Binomial), the offset is used to incorporate size factors.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        y : jnp.ndarray
            Response vector of shape (n_samples,).
        weight_fn : Callable
            Function to compute weights at each iteration.
        working_resid_fn : Callable
            Function to compute working residuals at each iteration.
        init_params : jnp.ndarray
            Initial parameter values.
        offset : jnp.ndarray | None, default=None
            Offset term (log scale for GLMs) to include in the model.
        **kwargs
            Additional arguments passed to weight_fn and working_resid_fn.

        Returns
        -------
        jnp.ndarray
            Optimized parameters.
        """
        n, p = X.shape
        eps = 1e-6

        # Handle offset
        if offset is None:
            offset = jnp.zeros(n)

        def irls_step(state):
            i, converged, beta = state

            # Compute weights and working residuals
            W = weight_fn(X, beta, offset=offset, **kwargs)
            z = working_resid_fn(X, y, beta, offset=offset, **kwargs)

            # Weighted design matrix
            W_sqrt = jnp.sqrt(W)
            X_weighted = X * W_sqrt[:, None]
            z_weighted = z * W_sqrt

            # Solve weighted least squares: (X^T W X) β = X^T W z
            XtWX = X_weighted.T @ X_weighted
            XtWz = X_weighted.T @ z_weighted
            beta_new = jax.scipy.linalg.solve(XtWX + eps * jnp.eye(p), XtWz, assume_a="pos")

            # Check convergence
            delta = jnp.max(jnp.abs(beta_new - beta))
            converged = delta < self.tol

            return i + 1, converged, beta_new

        def irls_cond(state):
            i, converged, _ = state
            return jnp.logical_and(i < self.maxiter, ~converged)

        # Initialize state
        state = (0, False, init_params)
        final_state = jax.lax.while_loop(irls_cond, irls_step, state)
        _, _, beta_final = final_state
        return beta_final

    def _compute_stats(
        self,
        X: jnp.ndarray,
        neg_ll_fn: Callable,
        params: jnp.ndarray,
        test_idx: int = -1,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Compute test statistics for fitted parameters.
        This method computes the Wald test statistics and p-values for the
        fitted parameters using the Hessian of the negative log-likelihood function.
        If the Hessian is ill-conditioned, it falls back to a likelihood ratio test.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        neg_ll_fn : Callable
            Function that computes the negative log-likelihood.
        params : jnp.ndarray
            Fitted parameter estimates.
        test_idx : int, default=-1
            Index of the parameter to test. If -1, tests the last parameter.

        Returns
        -------
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
            Standard errors, test statistics, and p-values.
        """  # noqa: D205
        hess_fn = jax.hessian(neg_ll_fn)
        hessian = hess_fn(params)
        hessian = 0.5 * (hessian + hessian.T)

        # Check condition number
        condition_number = jnp.linalg.cond(hessian)

        def wald_test():
            """Perform Wald test."""
            se = jnp.sqrt(jnp.clip(jnp.diag(jnp.linalg.inv(hessian)), 1e-8))
            stat = (params / se) ** 2
            pval = jsp.stats.chi2.sf(stat, df=1)
            return se, stat, pval

        def likelihood_ratio_test():
            """Perform likelihood ratio test as a fallback for ill-conditioned cases."""
            ll_full = -neg_ll_fn(params)
            params_reduced = params.at[test_idx].set(0.0)
            ll_reduced = -neg_ll_fn(params_reduced)
            # Compute likelihood ratio statistic
            lr_stat = 2 * (ll_full - ll_reduced)
            lr_stat = jnp.maximum(lr_stat, 0.0)
            # Compute correction for small sample sizes (where appropriate)
            n_samples = X.shape[0]
            n_params = X.shape[1]
            correction = 1 + n_params / jnp.maximum(1.0, n_samples - n_params)
            corrected_lr_stat = lr_stat / correction
            # Compute p-value for the likelihood ratio statistic
            lr_pval = jsp.stats.chi2.sf(corrected_lr_stat, df=1)
            # Return dummy values for SE and stat
            se = jnp.full_like(params, jnp.nan)
            stat = jnp.zeros_like(params)
            stat = stat.at[test_idx].set(lr_stat)
            pval = jnp.ones_like(params)
            pval = pval.at[test_idx].set(lr_pval)
            return se, stat, pval

        stats = jax.lax.cond(
            condition_number < 1e5,  # Relatively conservative threshold
            lambda _: wald_test(),
            lambda _: likelihood_ratio_test(),
            operand=None,
        )

        return stats

    def _exact_solution(self, X: jnp.ndarray, y: jnp.ndarray, offset: jnp.ndarray | None = None) -> jnp.ndarray:
        """Compute exact Ordinary Least Squares solution.

        For linear regression, the offset is incorporated by adjusting the
        response variable (y - offset) rather than the linear predictor.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        y : jnp.ndarray
            Response vector of shape (n_samples,).
        offset : jnp.ndarray | None, default=None
            Offset term to include in the model.

        Returns
        -------
        jnp.ndarray
            Coefficient estimates.
        """
        if offset is not None:
            # Adjust y by subtracting offset for linear regression
            y_adj = y - offset
        else:
            y_adj = y

        XtX = X.T @ X
        Xty = X.T @ y_adj
        params = jax.scipy.linalg.solve(XtX, Xty, assume_a="pos")
        return params

    def get_llf(self, X: jnp.ndarray, y: jnp.ndarray, params: jnp.ndarray, offset: jnp.ndarray | None = None) -> float:
        """Get log-likelihood at fitted parameters.

        This method converts the negative log-likelihood to a log-likelihood
        value, which is useful for model comparison and likelihood ratio tests.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        y : jnp.ndarray
            Response vector of shape (n_samples,).
        params : jnp.ndarray
            Parameter estimates.
        offset : jnp.ndarray | None, default=None
            Offset term to include in the model.

        Returns
        -------
        float
            Log-likelihood value.
        """
        nll = self._negative_log_likelihood(params, X, y, offset)
        return -nll  # Convert negative log-likelihood to log-likelihood


@dataclass(frozen=True)
class LinearRegression(Regression):
    """Linear regression with Ordinary Least Squares estimation.

    This class implements a basic linear regression model using OLS, with
    support for including offset terms. For linear models, offsets are applied
    by subtracting them from the response variable. This gives the same
    residuals as adding them to the linear predictor.

    The reported log-likelihood is Gaussian with unit noise variance, up to an
    additive constant: ``-0.5 * sum(residuals**2)``.

    Parameters
    ----------
    maxiter : int, default=100
        Inherited from Regression. Unused by :meth:`fit`, which computes the
        closed-form OLS solution.
    tol : float, default=1e-6
        Inherited from Regression. Unused by :meth:`fit`.
    optimizer : str, default="BFGS"
        Inherited from Regression. Unused by :meth:`fit`.
    skip_stats : bool, default=False
        Whether to skip calculating Wald test statistics (inherited from Regression).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from delnx.models import LinearRegression
    >>> X = jnp.array([[1.0, 0.5], [1.0, 1.5], [1.0, 2.5]])  # Design matrix with intercept
    >>> y = jnp.array([1.0, 2.0, 3.0])  # Response variable
    >>> model = LinearRegression()
    >>> result = model.fit(X, y)
    >>> print(f"Coefficients: {result['coef']}")
    """

    def _negative_log_likelihood(
        self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray, offset: jnp.ndarray | None = None
    ) -> float:
        """Compute the unit-variance Gaussian negative log-likelihood (up to a constant)."""
        pred = jnp.dot(X, params)
        if offset is not None:
            pred = pred + offset
        residuals = y - pred
        return 0.5 * jnp.sum(residuals**2)

    def _compute_cov_matrix(
        self, X: jnp.ndarray, params: jnp.ndarray, y: jnp.ndarray, offset: jnp.ndarray | None = None
    ) -> jnp.ndarray:
        """Compute the OLS covariance matrix ``sigma^2 * pinv(X^T X)``.

        ``sigma^2`` is the residual variance with ``n - p`` degrees of freedom.
        It is undefined (NaN or negative) when ``n <= p``.
        """
        n = X.shape[0]
        pred = X @ params
        if offset is not None:
            pred = pred + offset
        residuals = y - pred
        sigma2 = jnp.sum(residuals**2) / (n - len(params))
        return sigma2 * jnp.linalg.pinv(X.T @ X)

    def fit(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        test_idx: int = -1,
    ) -> dict:
        """Fit linear regression model.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        y : jnp.ndarray
            Response vector of shape (n_samples,).
        offset : jnp.ndarray | None, default=None
            Offset term to include in the model. It is subtracted from ``y``
            before solving.
        test_idx : int, default=-1
            Index of the coefficient whose Wald statistic and p-value are
            reported. The default tests the last coefficient.

        Returns
        -------
        dict
            Dictionary containing:

            - coef: Parameter estimates
            - llf: Log-likelihood at fitted parameters (unit-variance Gaussian,
              up to a constant)
            - se: Standard errors of all coefficients (:obj:`None` if `skip_stats=True`)
            - stat: Wald chi-square statistic for coefficient ``test_idx``
              (:obj:`None` if `skip_stats=True`)
            - pval: P-value of ``stat`` under chi2(df=1) (:obj:`None` if `skip_stats=True`)
        """
        params = self._exact_solution(X, y, offset)
        llf = self.get_llf(X, y, params, offset)

        se = stat = pval = None
        if not self.skip_stats:
            cov = self._compute_cov_matrix(X, params, y, offset)
            se = jnp.sqrt(jnp.diag(cov))
            stat = (params[test_idx] / se[test_idx]) ** 2
            pval = jsp.stats.chi2.sf(stat, df=1)

        return {"coef": params, "llf": llf, "se": se, "stat": stat, "pval": pval}

    def predict(self, X: jnp.ndarray, params: jnp.ndarray, offset: jnp.ndarray | None = None) -> jnp.ndarray:
        """Predict response variable using fitted model.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        params : jnp.ndarray
            Fitted coefficient estimates.
        offset : jnp.ndarray | None, default=None
            Offset term added to the prediction.

        Returns
        -------
        jnp.ndarray
            Predicted response variable.
        """
        pred = X @ params
        if offset is not None:
            # Rebind rather than ``+=`` so NumPy inputs are never modified in place
            # or rejected by an in-place dtype cast.
            pred = pred + offset
        return pred
@dataclass(frozen=True)
class LogisticRegression(Regression):
    """Logistic regression in JAX.

    This class implements logistic regression for binary classification tasks
    with support for offset terms. Offsets are added to the linear predictor
    before applying the logistic function.

    Parameters
    ----------
    maxiter : int, default=100
        Maximum number of iterations for optimization algorithms.
    tol : float, default=1e-6
        Convergence tolerance for optimization algorithms.
    optimizer : str, default="BFGS"
        Optimization method to use. Options are "BFGS" or "IRLS" (recommended).
    skip_stats : bool, default=False
        Whether to skip calculating test statistics.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from delnx.models import LogisticRegression
    >>> X = jnp.array([[1.0, 0.5], [1.0, 1.5], [1.0, 2.5]])  # Design matrix with intercept
    >>> y = jnp.array([0.0, 0.0, 1.0])  # Binary outcome
    >>> model = LogisticRegression(optimizer="IRLS")
    >>> result = model.fit(X, y)
    >>> print(f"Coefficients: {result['coef']}")
    """

    def _negative_log_likelihood(
        self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray, offset: jnp.ndarray | None = None
    ) -> float:
        """Compute the Bernoulli negative log-likelihood with an optional offset.

        ``logaddexp(0, logits)`` evaluates ``log(1 + exp(logits))`` without overflow.
        """
        logits = jnp.dot(X, params)
        if offset is not None:
            logits = logits + offset
        return -jnp.sum(y * logits - jnp.logaddexp(0.0, logits))

    def _weight_fn(self, X: jnp.ndarray, beta: jnp.ndarray, offset: jnp.ndarray | None = None) -> jnp.ndarray:
        """Compute IRLS weights ``p * (1 - p)`` with an optional offset."""
        eta = X @ beta
        if offset is not None:
            eta = eta + offset
        # Clip the linear predictor to keep the sigmoid away from exact 0/1.
        eta = jnp.clip(eta, -50, 50)
        p = jax.nn.sigmoid(eta)
        return p * (1 - p)

    def _working_resid_fn(
        self, X: jnp.ndarray, y: jnp.ndarray, beta: jnp.ndarray, offset: jnp.ndarray | None = None
    ) -> jnp.ndarray:
        """Compute the IRLS working response ``eta + (y - p) / (p (1 - p))``."""
        eta = X @ beta
        if offset is not None:
            eta = eta + offset
        eta = jnp.clip(eta, -50, 50)
        p = jax.nn.sigmoid(eta)
        # Floor the variance so saturated probabilities do not divide by ~0.
        return eta + (y - p) / jnp.clip(p * (1 - p), 1e-6)

    def fit(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        test_idx: int = -1,
    ) -> dict:
        """Fit logistic regression model.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        y : jnp.ndarray
            Binary response vector of shape (n_samples,).
        offset : jnp.ndarray | None, default=None
            Offset term added to the linear predictor.
        test_idx : int, default=-1
            Index of the parameter tested by the likelihood-ratio fallback used
            when the Hessian is ill-conditioned. The default tests the last parameter.

        Returns
        -------
        dict
            Dictionary containing:

            - coef: Parameter estimates
            - llf: Log-likelihood at fitted parameters
            - se: Standard errors (:obj:`None` if `skip_stats=True`)
            - stat: Test statistics (:obj:`None` if `skip_stats=True`)
            - pval: P-values (:obj:`None` if `skip_stats=True`)

        Raises
        ------
        ValueError
            If ``optimizer`` is neither "BFGS" nor "IRLS".
        """
        init_params = jnp.zeros(X.shape[1])
        if self.optimizer == "BFGS":
            nll = partial(self._negative_log_likelihood, X=X, y=y, offset=offset)
            params = self._fit_bfgs(nll, init_params)
        elif self.optimizer == "IRLS":
            params = self._fit_irls(X, y, self._weight_fn, self._working_resid_fn, init_params, offset=offset)
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer}")

        llf = self.get_llf(X, y, params, offset)

        se = stat = pval = None
        if not self.skip_stats:
            nll = partial(self._negative_log_likelihood, X=X, y=y, offset=offset)
            se, stat, pval = self._compute_stats(X, nll, params, test_idx=test_idx)

        return {
            "coef": params,
            "llf": llf,
            "se": se,
            "stat": stat,
            "pval": pval,
        }

    def predict(self, X: jnp.ndarray, params: jnp.ndarray, offset: jnp.ndarray | None = None) -> jnp.ndarray:
        """Predict probabilities using fitted model.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        params : jnp.ndarray
            Fitted parameter estimates.
        offset : jnp.ndarray | None, default=None
            Offset term added to the logits before the sigmoid.

        Returns
        -------
        jnp.ndarray
            Predicted probabilities of the positive class.
        """
        logits = X @ params
        if offset is not None:
            logits = logits + offset
        return jax.nn.sigmoid(logits)
@dataclass(frozen=True)
class NegativeBinomialRegression(Regression):
    """Negative Binomial regression in JAX.

    This class implements Negative Binomial regression for modeling count
    data, particularly RNA-seq data. It supports offsets to incorporate size
    factors or other normalization terms. The model uses a log link and
    variance ``mu + dispersion * mu**2``.

    Parameters
    ----------
    maxiter : int, default=100
        Maximum number of iterations for optimization algorithms.
    tol : float, default=1e-6
        Convergence tolerance for optimization algorithms.
    optimizer : str, default="BFGS"
        Optimization method to use. Options are "BFGS" or "IRLS".
    skip_stats : bool, default=False
        Whether to skip calculating Wald test statistics.
    dispersion : float | None, default=None
        Fixed dispersion parameter. It is required by :meth:`fit`, which raises
        ``ValueError`` when it is :obj:`None`; the error message points to
        ``nb_fit()`` for automatic dispersion estimation.
    dispersion_range : tuple[float, float], default=(1e-8, 100.0)
        Range the dispersion is clipped to, to avoid numerical issues.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from delnx.models import NegativeBinomialRegression
    >>> X = jnp.array([[1.0, 0.0], [1.0, 1.0]])  # Design matrix with intercept
    >>> y = jnp.array([10.0, 20.0])  # Count data
    >>> size_factors = jnp.array([0.8, 1.2])  # Size factors from normalization
    >>> offset = jnp.log(size_factors)  # Log transform for offset
    >>> model = NegativeBinomialRegression(optimizer="IRLS", dispersion=0.1)
    >>> result = model.fit(X, y, offset=offset)
    >>> print(f"Coefficients: {result['coef']}")
    """

    dispersion: float | None = None
    dispersion_range: tuple[float, float] = (1e-8, 100.0)

    def _negative_log_likelihood(
        self,
        params: jnp.ndarray,
        X: jnp.ndarray,
        y: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        dispersion: float = 1.0,
    ) -> float:
        """Compute the NB negative log-likelihood with an optional offset.

        Uses size ``r = 1 / dispersion`` (dispersion clipped to
        ``dispersion_range``) and mean ``mu = exp(eta)``.
        """
        eta = X @ params
        if offset is not None:
            eta = eta + offset
        eta = jnp.clip(eta, -50, 50)
        mu = jnp.exp(eta)
        r = 1 / jnp.clip(dispersion, self.dispersion_range[0], self.dispersion_range[1])
        # r * log(r / (r + mu)) is rewritten as -r * log1p(mu / r). This keeps
        # precision when r is large (small dispersion), where r / (r + mu) is ~1.
        # y * log(mu / (r + mu)) equals y * (eta - log(r + mu)) because mu = exp(eta).
        ll = (
            jsp.special.gammaln(r + y)
            - jsp.special.gammaln(r)
            - jsp.special.gammaln(y + 1)
            - r * jnp.log1p(mu / r)
            + y * (eta - jnp.log(r + mu))
        )
        return -jnp.sum(ll)

    def _weight_fn(
        self, X: jnp.ndarray, beta: jnp.ndarray, offset: jnp.ndarray | None = None, dispersion: float = 1.0
    ) -> jnp.ndarray:
        """Compute IRLS weights ``mu^2 / var`` (log link) with an optional offset."""
        eta = X @ beta
        if offset is not None:
            eta = eta + offset
        eta = jnp.clip(eta, -50, 50)
        mu = jnp.exp(eta)
        # Negative binomial variance = mu + phi * mu^2
        var = mu + dispersion * mu**2
        # IRLS weights: (d mu / d eta)^2 / var, and d mu / d eta = mu for the log link.
        return mu**2 / jnp.clip(var, 1e-6)

    def _working_resid_fn(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        beta: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        dispersion: float = 1.0,
    ) -> jnp.ndarray:
        """Compute the IRLS working response ``eta + (y - mu) / mu`` (log link)."""
        eta = X @ beta
        if offset is not None:
            eta = eta + offset
        # Clipping keeps mu = exp(eta) strictly positive, so the division is safe.
        eta = jnp.clip(eta, -50, 50)
        mu = jnp.exp(eta)
        # Working response: z = eta + (y - mu) * (d eta / d mu), with d eta / d mu = 1 / mu.
        return eta + (y - mu) / mu

    def get_llf(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        params: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        dispersion: float = 1.0,
    ) -> float:
        """Get log-likelihood at fitted parameters with offset and given dispersion."""
        nll = self._negative_log_likelihood(params, X, y, offset, dispersion)
        return -nll

    def fit(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        offset: jnp.ndarray | None = None,
        test_idx: int = -1,
    ) -> dict:
        """Fit negative binomial regression model with optional offset.

        Fits a Negative Binomial regression with the fixed dispersion given at
        construction (clipped to ``dispersion_range``). The dispersion is not
        estimated here.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features). The intercept is
            assumed to be column 0, which receives a data-driven starting value.
        y : jnp.ndarray
            Count response vector of shape (n_samples,).
        offset : jnp.ndarray | None, default=None
            Offset term (log scale) added to the linear predictor. Typically
            log(size_factors) for RNA-seq data.
        test_idx : int, default=-1
            Index of the parameter tested by the likelihood-ratio fallback used
            when the Hessian is ill-conditioned. The default tests the last parameter.

        Returns
        -------
        dict
            Dictionary containing:

            - coef: Parameter estimates
            - llf: Log-likelihood at fitted parameters
            - se: Standard errors (:obj:`None` if `skip_stats=True`)
            - stat: Test statistics (:obj:`None` if `skip_stats=True`)
            - pval: P-values (:obj:`None` if `skip_stats=True`)
            - dispersion: The (clipped) dispersion value used for the fit

        Raises
        ------
        ValueError
            If ``dispersion`` is :obj:`None`, or ``optimizer`` is unsupported.
        """
        if self.dispersion is None:
            raise ValueError(
                "A dispersion value must be provided. Use nb_fit() for automatic dispersion estimation."
            )
        dispersion = jnp.clip(self.dispersion, self.dispersion_range[0], self.dispersion_range[1])

        # Start all coefficients at zero except the intercept (column 0). It
        # starts at log of the mean count, adjusted by the mean offset.
        init_params = jnp.zeros(X.shape[1])
        mean_y = jnp.maximum(jnp.mean(y), 1e-8)
        if offset is not None:
            init_params = init_params.at[0].set(jnp.log(mean_y) - jnp.mean(offset))
        else:
            init_params = init_params.at[0].set(jnp.log(mean_y))

        if self.optimizer == "BFGS":
            nll = partial(self._negative_log_likelihood, X=X, y=y, offset=offset, dispersion=dispersion)
            params = self._fit_bfgs(nll, init_params)
        elif self.optimizer == "IRLS":
            params = self._fit_irls(
                X, y, self._weight_fn, self._working_resid_fn, init_params, offset=offset, dispersion=dispersion
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer}")

        llf = self.get_llf(X, y, params, offset, dispersion)

        se = stat = pval = None
        if not self.skip_stats:
            nll = partial(self._negative_log_likelihood, X=X, y=y, offset=offset, dispersion=dispersion)
            se, stat, pval = self._compute_stats(X, nll, params, test_idx=test_idx)

        return {
            "coef": params,
            "llf": llf,
            "se": se,
            "stat": stat,
            "pval": pval,
            "dispersion": dispersion,
        }

    def predict(
        self,
        X: jnp.ndarray,
        params: jnp.ndarray,
        offset: jnp.ndarray | None = None,
    ) -> jnp.ndarray:
        """Predict expected counts using fitted model.

        Parameters
        ----------
        X : jnp.ndarray
            Design matrix of shape (n_samples, n_features).
        params : jnp.ndarray
            Fitted parameter estimates.
        offset : jnp.ndarray | None, default=None
            Offset term (log scale) added to the linear predictor.

        Returns
        -------
        jnp.ndarray
            Predicted mean counts ``exp(clip(X @ params + offset, -50, 50))``.
        """
        eta = X @ params
        if offset is not None:
            eta = eta + offset
        eta = jnp.clip(eta, -50, 50)
        return jnp.exp(eta)