import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from adjustText import adjust_text
from delnx._utils import get_de_genes
class VolcanoPlot:
"""Static volcano plot using matplotlib."""
DEFAULT_COLOR_LEGEND_TITLE = "-log10(p-value)"
DEFAULT_SAVE_PREFIX = "volcanoplot"
def __init__(
self,
df: pd.DataFrame,
x: str = "coef",
y: str = "-log10(pval)",
thresh: dict[str, float] | None = None,
color_legend_title: str | None = None,
ax: plt.Axes | None = None,
figsize: tuple[float, float] = (8, 6),
save_path: str | None = None,
):
self.df = df.copy()
self.x = x
self.y = y
self.thresh = thresh or {}
self.color_legend_title = color_legend_title or self.DEFAULT_COLOR_LEGEND_TITLE
self.save_path = save_path
self.figsize = figsize
self.ax = ax
# Compute -log10(pval) if needed and y is set to that
if self.y == "-log10(pval)" and "-log10(pval)" not in self.df.columns:
self.df["-log10(pval)"] = -np.log10(self.df["pval"])
# Default color map
self.color_map = {
"NS": "#d1d5db", # gray-300
"Up": "#ef4444", # red-500
"Down": "#3b82f6", # blue-500
}
def style(self, color_map: dict[str, str] | None = None) -> "VolcanoPlot":
if color_map:
self.color_map = color_map
return self
def make_figure(self) -> "VolcanoPlot":
if self.ax is not None:
self.fig = self.ax.figure
else:
self.fig, self.ax = plt.subplots(figsize=self.figsize)
# Plot each significance group
for label, color in self.color_map.items():
subset = self.df[self.df["significant"] == label]
self.ax.scatter(
subset[self.x],
subset[self.y],
c=color,
label=label,
edgecolor="black",
linewidth=0.5,
s=20,
alpha=0.8,
)
# Threshold lines
x_thresh = self.thresh.get(self.x, None)
y_thresh = self.thresh.get(self.y, None)
if y_thresh is not None:
self.ax.axhline(y=y_thresh, color="black", linestyle="--", linewidth=1)
if x_thresh is not None:
self.ax.axvline(x=x_thresh, color="black", linestyle="--", linewidth=1)
self.ax.axvline(x=-x_thresh, color="black", linestyle="--", linewidth=1)
# Labels and grid
self.ax.set_xlabel(self.x if self.x != "coef" else "Estimated Coefficient")
self.ax.set_ylabel(self.y if self.y != "-log10(pval)" else "-log10(p-value)")
self.ax.legend(title=self.color_legend_title)
self.ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
return self
def add_labels(self, top_up: list[str], top_down: list[str]) -> None:
if self.ax is None:
raise RuntimeError("Plot must be initialized before adding labels.")
texts = []
for feature in top_up + top_down:
row = self.df[self.df["feature"] == feature].iloc[0]
texts.append(
self.ax.text(
row[self.x],
row[self.y],
feature,
fontsize=8,
ha="right" if row[self.x] > 0 else "left",
va="bottom",
bbox={
"boxstyle": "round,pad=0.3",
"facecolor": "white",
"edgecolor": "black",
"linewidth": 0.5,
},
)
)
adjust_text(
texts,
ax=self.ax,
arrowprops={
"arrowstyle": "-",
"color": "gray",
"lw": 0.5,
},
)
def show(self):
if self.fig is None:
self.make_figure()
plt.show()
def save(self):
if self.fig and self.save_path:
self.fig.savefig(self.save_path, bbox_inches="tight", dpi=300)
def get_figure(self):
if self.fig is None:
self.make_figure()
return self.fig, self.ax
[docs]
def volcanoplot(
df: pd.DataFrame,
x: str = "log2fc",
y: str = "-log10(pval)",
effect_key: str = "log2fc",
pval_key: str = "pval",
feature_key: str = "feature",
effect_thresh: float = 0.5,
pval_thresh: float = 0.01,
thresh: dict[str, float] | None = None,
label_top: int = 0,
color_legend_title: str | None = None,
ax: plt.Axes | None = None,
figsize: tuple[float, float] = (8, 6),
show: bool | None = True,
save: str | bool | None = None,
return_fig: bool = False,
):
"""
Create a volcano plot using matplotlib. Labels significant genes based on thresholds or uses "significant" column if present.
Parameters
----------
df : pd.DataFrame
DataFrame containing differential expression results.
x : str, default="log2fc"
Column name for x-axis (effect size).
y : str, default="-log10(pval)"
Column name for y-axis (significance).
effect_key : str, default="log2fc"
Column with effect size values for DE analysis.
pval_key : str, default="pval"
Column with p-values for DE analysis.
feature_key : str, default="feature"
Column containing gene names for labeling.
effect_thresh : float, default=0.5
Threshold for absolute effect size.
pval_thresh : float, default=0.01
Threshold for significance.
thresh : dict[str, float] or None
Dictionary mapping axis names to threshold values for plot lines,
e.g. {'log2fc': 1.0, '-log10(pval)': 1.3}.
If None, uses effect_thresh and -log10(pval_thresh).
label_top : int, default=0
If > 0, label top N up/down genes by effect size.
color_legend_title : str or None
Title for the legend. Default: "-log10(p-value)".
ax : plt.Axes or None
If provided, use this Axes for plotting instead of creating a new one.
If None, a new Axes will be created.
figsize : tuple[float, float], default=(8, 6)
Size of the figure in inches.
show : bool or None, default=True
Whether to display the figure interactively.
save : str or bool or None
If str, path to save the image. If True, uses default name.
return_fig : bool, default=False
Whether to return the matplotlib Figure and Axes.
Returns
-------
VolcanoPlot or tuple[Figure, Axes] or None
"""
# Check group uniqueness if group column exists
if "group" in df.columns:
unique_groups = df["group"].unique()
if len(unique_groups) > 1:
raise ValueError(f"Volcano plot expects a single group, but found multiple: {unique_groups}")
# Use analyze_de_genes to label the dataframe if needed
df_to_plot = df.copy()
# Check if significance labels already exist, if not, add them
if "significant" not in df_to_plot.columns or y not in df_to_plot.columns:
_, df_to_plot = get_de_genes(
df,
effect_key=effect_key,
pval_key=pval_key,
feature_key=feature_key,
effect_thresh=effect_thresh,
pval_thresh=pval_thresh,
return_labeled_df=True,
)
# Set default thresholds if not provided
if thresh is None:
thresh = {x: effect_thresh, y: -np.log10(pval_thresh)}
save_path = None
if isinstance(save, str):
save_path = save
elif save is True:
save_path = f"{VolcanoPlot.DEFAULT_SAVE_PREFIX}.pdf"
vp = VolcanoPlot(
df_to_plot,
x=x,
y=y,
thresh=thresh,
color_legend_title=color_legend_title,
figsize=figsize,
ax=ax,
save_path=save_path,
).make_figure()
# Add labels for top genes if requested
if label_top > 0 and feature_key in df_to_plot.columns:
# Extract top genes using our unified function
de_genes_dict = get_de_genes(
df_to_plot,
effect_key=effect_key,
pval_key=pval_key,
feature_key=feature_key,
effect_thresh=effect_thresh,
pval_thresh=pval_thresh,
top_n=label_top,
)
# Get the first (and should be only) group's genes
group_key = list(de_genes_dict.keys())[0]
top_up = de_genes_dict[group_key]["up"]
top_down = de_genes_dict[group_key]["down"]
vp.add_labels(top_up, top_down)
if save_path:
vp.save()
if show:
vp.show()
if return_fig:
return vp.get_figure()