Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 79 additions & 13 deletions pf2/figures/figureA3b_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,23 @@
Figure A3b_c
"""

import numpy as np
import pandas as pd
import anndata
import seaborn as sns
from statsmodels.stats.anova import anova_lm
from statsmodels.formula.api import ols
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from ..data_import import condition_factors_meta, add_obs, combine_cell_types
from ..predict import predict_mortality_all, plsr_acc_proba
from ..predict import predict_mortality_all, pca_acc_proba
from .common import subplotLabel, getSetup
from sklearn.metrics import RocCurveDisplay


def makeFigure():
"""Get a list of the axis objects and create a figure."""
ax, f = getSetup((6, 6), (2, 2))
ax, f = getSetup((9, 9), (3, 2))
subplotLabel(ax)

X = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")
Expand All @@ -25,36 +30,46 @@ def makeFigure():

roc_auc = [False, True]
for i in range(2):
plsr_acc_df = pd.DataFrame([])
for j in range(3):
df = plsr_acc_proba(
pca_acc_df = pd.DataFrame([])
for j in range(5):
df = pca_acc_proba(
X, cond_fact_meta_df, n_components=j + 1, roc_auc=roc_auc[i]
)
df["Component"] = j + 1
plsr_acc_df = pd.concat([plsr_acc_df, df], axis=0)
pca_acc_df = pd.concat([pca_acc_df, df], axis=0)

plsr_acc_df = plsr_acc_df.melt(
pca_acc_df = pca_acc_df.melt(
id_vars="Component", var_name="Category", value_name="Accuracy"
)
sns.barplot(
data=plsr_acc_df, x="Component", y="Accuracy", hue="Category", ax=ax[i],
data=pca_acc_df, x="Component", y="Accuracy", hue="Category", ax=ax[i],
hue_order=["C19", "nC19", "Overall"]
)
if roc_auc[i] is True:
ax[i].set(ylim=[0, 1], ylabel="AUC ROC")
else:
ax[i].set(ylim=[0, 1], ylabel="Prediction Accuracy")

# Find the top performing PCA models based on best total accuracy within a component
top_acc_idx = pca_acc_df.groupby("Component")["Accuracy"].idxmax().tolist()
pca_acc_df_top = pca_acc_df.loc[top_acc_idx].sort_values("Accuracy", ascending=False)
top_performing = pca_acc_df_top["Component"].tolist()

# Get the top 2 performing PCA models and plot the ROC curve for each
for i in range(2):
plot_plsr_auc_roc(X, cond_fact_meta_df, n_components=i + 1, ax=ax[i + 2])
ax[i + 2].set(title=f"PLSR {i + 1} Components")
plot_pca_auc_roc(X, cond_fact_meta_df, n_components=top_performing[i], ax=ax[i + 2])
ax[i + 2].set(title=f"PCA {top_performing[i]} Components")

anova_df = pc_anova_analysis(cond_fact_meta_df)
plot_pc_anova_heatmap(anova_df, ax=ax[4])
ax[5].set_visible(False)

return f



def plot_plsr_auc_roc(X, patient_factor_matrix, n_components, ax):
"""Runs PLSR and plots ROC AUC based on actual and prediction labels"""
def plot_pca_auc_roc(X, patient_factor_matrix, n_components, ax):
"""Runs PCA and plots ROC AUC based on actual and prediction labels"""
probabilities_all, labels_all = predict_mortality_all(X,
patient_factor_matrix, n_components=n_components, proba=True)

Expand All @@ -72,4 +87,55 @@ def plot_plsr_auc_roc(X, patient_factor_matrix, n_components, ax):
)
RocCurveDisplay.from_predictions(
labels_all.to_numpy().astype(int), probabilities_all, plot_chance_level=True, ax=ax, name="Overall"
)
)


def pc_anova_analysis(cond_fact_meta_df, n_components=5):
    """Test each principal component for association with outcome and cohort.

    Rows whose "patient_category" is "Non-Pneumonia Control" are dropped, the
    "Cmp."-prefixed columns are standardized and projected onto
    ``n_components`` principal components, and a one-way type-II ANOVA is run
    per component for three comparisons.

    Args:
        cond_fact_meta_df: DataFrame holding the "Cmp."-prefixed feature
            columns plus the "patient_category" and "binary_outcome" columns.
        n_components: Number of principal components to fit and test.

    Returns:
        DataFrame with one row per component ("PC1", "PC2", ...) and the
        p-value columns "C19 Outcome", "nC19 Outcome", and "C19 vs nC19".

    Raises:
        ValueError: If no column name starts with "Cmp.".
    """
    feature_cols = [c for c in cond_fact_meta_df.columns if c.startswith("Cmp.")]
    if not feature_cols:
        # Fail with a clear message instead of an opaque error from the scaler.
        raise ValueError("cond_fact_meta_df has no columns starting with 'Cmp.'")

    df = cond_fact_meta_df[
        cond_fact_meta_df["patient_category"] != "Non-Pneumonia Control"
    ].copy()

    pc_names = [f"PC{i + 1}" for i in range(n_components)]
    pca_scores = PCA(n_components=n_components).fit_transform(
        StandardScaler().fit_transform(df[feature_cols])
    )
    pc_df = pd.DataFrame(pca_scores, columns=pc_names, index=df.index)
    pc_df["binary_outcome"] = df["binary_outcome"]
    pc_df["patient_category"] = df["patient_category"]
    pc_df["is_c19"] = (pc_df["patient_category"] == "COVID-19").astype(int)

    # The cohort subsets do not depend on the component being tested, so they
    # are built once here rather than on every loop iteration.
    c19_rows = pc_df["patient_category"] == "COVID-19"
    c19 = pc_df[c19_rows].dropna(subset=["binary_outcome"])
    nc19 = pc_df[~c19_rows].dropna(subset=["binary_outcome"])

    def _p(formula, data, term):
        """Return the type-II ANOVA p-value of ``term`` for an OLS fit."""
        return anova_lm(ols(formula, data=data).fit(), typ=2).loc[term, "PR(>F)"]

    results = {}
    for pc in pc_names:
        outcome_formula = f"{pc} ~ C(binary_outcome)"
        results[pc] = {
            "C19 Outcome": _p(outcome_formula, c19, "C(binary_outcome)"),
            "nC19 Outcome": _p(outcome_formula, nc19, "C(binary_outcome)"),
            "C19 vs nC19": _p(f"{pc} ~ C(is_c19)", pc_df, "C(is_c19)"),
        }

    # Transpose so components are rows and comparisons are columns.
    return pd.DataFrame(results).T


def plot_pc_anova_heatmap(anova_df, ax):
    """Draw an annotated heatmap of -log10(p-value) from the per-PC ANOVA table.

    Args:
        anova_df: DataFrame of p-values; its rows and columns become the
            heatmap rows and columns.
        ax: Matplotlib axes to draw on.
    """
    p_values = anova_df.astype(float)
    # A p-value that underflows to exactly 0.0 would turn into +inf and break
    # the colour scale, so floor at the smallest positive normal float.
    # NaN entries pass through the clip unchanged.
    neg_log_p = -np.log10(p_values.clip(lower=np.finfo(float).tiny))
    sns.heatmap(
        neg_log_p,
        ax=ax,
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
        linewidths=0.5,
        cbar_kws={"label": "-log10(p-value)"},
    )
    # Thin rule along the top edge of the heatmap.
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.set(title="PC Associations (ANOVA)", xlabel="", ylabel="")
179 changes: 125 additions & 54 deletions pf2/figures/figureA3d_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,104 +2,175 @@
Figure A3d_g
"""

import anndata
import numpy as np
import pandas as pd
import anndata
import seaborn as sns
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from ..data_import import condition_factors_meta
from ..predict import plsr_acc
from ..predict import pca_acc
from .common import subplotLabel, getSetup
import seaborn as sns

# Principal components (1-based) inspected for each cohort. Change these two
# numbers to explore other PCs; the labels and lookup below follow them.
C19_TARGET, nC19_TARGET = 3, 4

# Display labels derived from the target numbers.
C19_TARGET_STR = "PC " + str(C19_TARGET)
nC19_TARGET_STR = "PC " + str(nC19_TARGET)

# Label -> PC number, with the C19 entry first and the nC19 entry second.
target_pcs = dict(
    zip(
        (C19_TARGET_STR, nC19_TARGET_STR),
        (C19_TARGET, nC19_TARGET),
    )
)

# Helper function to get keys by index since dicts don't support direct indexing
def get_nth_key(target_dict, n):
    """Return the key at position ``n`` in the dict's insertion order.

    Args:
        target_dict: Mapping whose keys are looked up by position.
        n: Integer position; negative values count from the end, as with
            list indexing.

    Returns:
        The key stored at position ``n``.

    Raises:
        IndexError: If ``n`` is outside the range of available keys.
    """
    keys = list(target_dict)
    try:
        return keys[n]
    except IndexError:
        # Keep the original error message for callers that report it.
        raise IndexError("Index out of range") from None

def makeFigure():
"""Get a list of the axis objects and create a figure."""
ax, f = getSetup((3, 7), (4, 1))
ax, f = getSetup((6, 8), (3, 2))
subplotLabel(ax)

X = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")

cond_fact_meta_df = condition_factors_meta(X)

labels, plsr_results_both = plsr_acc(X, cond_fact_meta_df, n_components=1)

plot_plsr_loadings(plsr_results_both, ax[0], ax[1])
labels, pca_results_both = pca_acc(X, cond_fact_meta_df, n_components=5)

plot_pca_loadings(pca_results_both, ax[0], ax[1])
ax[0].set(xlim=[-0.35, 0.35])
ax[1].set(xlim=[-0.35, 0.35])

plot_plsr_scores(plsr_results_both, cond_fact_meta_df, labels, ax[2], ax[3])
ax[2].set(xlim=[-7, 7])
ax[3].set(xlim=[-9.5, 9.5])
plot_pca_scores(pca_results_both, cond_fact_meta_df, labels, ax[2], ax[3])

plot_pca_scores_2d(pca_results_both, cond_fact_meta_df, labels, ax[4], ax[5])

return f


def plot_plsr_loadings(plsr_results, ax1, ax2):
"""Runs PLSR and plots ROC AUC based on actual and prediction labels"""
def plot_pca_loadings(pca_results, ax1, ax2):
"""Plots PCA component loadings for C19 and nC19 models."""
ax = [ax1, ax2]
type_of_data = ["C19", "nC19"]

for i in range(2):
x_load = plsr_results[i].x_loadings_[:, 0]
pca = pca_results[i].named_steps["pca"]

x_load = pca.components_[target_pcs[get_nth_key(target_pcs, i)] - 1, :]
if i == 1:
x_load =-1*x_load
df_xload = pd.DataFrame(data=x_load, columns=["PLSR 1"])
x_load = -1 * x_load
df_xload = pd.DataFrame(data=x_load, columns=[get_nth_key(target_pcs, i)])
df_xload["Component"] = np.arange(df_xload.shape[0]) + 1
print(df_xload.sort_values(by="PLSR 1"))
y_load = plsr_results[i].y_loadings_[0, 0]
if i == 1:
y_load =-1*y_load
df_yload = pd.DataFrame(data=[[y_load]], columns=["PLSR 1"])
sns.swarmplot(
data=df_xload,
x="PLSR 1",
ax=ax[i],
color="k",
)
sns.swarmplot(
data=df_yload,
x="PLSR 1",
ax=ax[i],
color="r",

print(df_xload.sort_values(by=get_nth_key(target_pcs, i)))

sns.swarmplot(data=df_xload, x=get_nth_key(target_pcs, i), ax=ax[i], color="k")
ax[i].set(
xlabel=get_nth_key(target_pcs, i), ylabel="Pf2 Components", title=f"{type_of_data[i]}-loadings"
)
ax[i].set(xlabel="PLSR 1", ylabel="Pf2 Components", title=f"{type_of_data[i]}-loadings")


def plot_plsr_scores(plsr_results, cond_fact_meta_df, labels, ax1, ax2):
"""Runs PLSR and plots ROC AUC based on actual and prediction labels"""
def plot_pca_scores(pca_results, cond_fact_meta_df, labels, ax1, ax2):
"""Plots PCA scores for C19 and nC19 patients colored by mortality outcome."""
ax = [ax1, ax2]
type_of_data = ["C19", "nC19"]

cond_fact_meta_df = cond_fact_meta_df.loc[
cond_fact_meta_df.loc[:, "patient_category"] != "Non-Pneumonia Control", :
]

cmp_cols = [c for c in cond_fact_meta_df.columns if c.startswith("Cmp.")]
c19_mask = cond_fact_meta_df.loc[:, "patient_category"] == "COVID-19"

for i in range(2):
if i == 0:
score_labels = labels.loc[
cond_fact_meta_df.loc[:, "patient_category"] == "COVID-19"
]
else:
score_labels = labels.loc[
cond_fact_meta_df.loc[:, "patient_category"] != "COVID-19"
]
mask = c19_mask if i == 0 else ~c19_mask
score_labels = labels.loc[mask]
subset_data = cond_fact_meta_df.loc[mask, cmp_cols]

pal = sns.color_palette()
if i == 0:
numb1=0; numb2=2
else:
numb1=1; numb2=3

x_scores = plsr_results[i].x_scores_[:, 0]
x_scores = pca_results[i][:-1].transform(subset_data)[:, target_pcs[get_nth_key(target_pcs, i)] - 1]
if i == 1:
x_scores =-1*x_scores
df_xscores = pd.DataFrame(data=x_scores, columns=["PLSR 1"])
x_scores = -1 * x_scores

pal = sns.color_palette()
numb1, numb2 = (0, 2) if i == 0 else (1, 3)

df_xscores = pd.DataFrame(data=x_scores, columns=[get_nth_key(target_pcs, i)])
sns.swarmplot(
data=df_xscores,
x="PLSR 1",
x=get_nth_key(target_pcs, i),
ax=ax[i],
hue=score_labels.to_numpy(),
palette=[pal[numb1], pal[numb2]],
hue_order=[1, 0],
)
ax[i].set(xlabel="PLSR 1", ylabel="Samples", title=f"{type_of_data[i]}-scores")
ax[i].set(xlabel=get_nth_key(target_pcs, i), ylabel="Samples", title=f"{type_of_data[i]}-scores")

def plot_pca_scores_2d(pca_results, cond_fact_meta_df, labels, ax1, ax2):
    """Plot COVID-19 PCA scores in 2D with an RBF-SVM decision boundary.

    Rows labelled "Non-Pneumonia Control" are dropped and the remaining
    COVID-19 rows are projected with ``pca_results[0]`` minus its final step.
    An RBF SVM is cross-validated on the first five score columns to report a
    ROC AUC in the title. For each PC pair a separate two-feature SVM is
    fitted only to draw the decision regions behind the scatter plot.

    Args:
        pca_results: Sequence whose first entry supports ``[:-1]`` slicing and
            ``transform`` (presumably a fitted sklearn Pipeline for the C19
            model; confirm against ``pca_acc``).
        cond_fact_meta_df: DataFrame with "Cmp."-prefixed feature columns and
            a "patient_category" column.
        labels: Outcome labels selectable with the COVID-19 row mask (assumed
            to share the index of the filtered ``cond_fact_meta_df``).
        ax1: Axes for the first PC pair.
        ax2: Axes for a second PC pair; hidden when no second pair is defined.
    """
    axes = [ax1, ax2]
    pc_pairs = [(1, 2)]  # zero-based score columns, i.e. PC 2 vs. PC 3

    cond_fact_meta_df = cond_fact_meta_df.loc[
        cond_fact_meta_df.loc[:, "patient_category"] != "Non-Pneumonia Control", :
    ]

    cmp_cols = [c for c in cond_fact_meta_df.columns if c.startswith("Cmp.")]
    c19_mask = cond_fact_meta_df.loc[:, "patient_category"] == "COVID-19"

    score_labels = labels.loc[c19_mask]
    subset_data = cond_fact_meta_df.loc[c19_mask, cmp_cols]
    scores = pca_results[0][:-1].transform(subset_data)[:, :5]
    y_svm = score_labels.to_numpy().astype(int)

    # AUC of the SVM on all five score columns, reported in the title.
    # probability=True is not needed: ROC-AUC scoring and the boundary plot
    # both rely on decision_function, and enabling it only adds internal
    # cross-validated fits.
    cv_auc_5d = cross_val_score(
        SVC(kernel="rbf"),
        scores,
        y_svm,
        cv=5,
        scoring="roc_auc",
    ).mean()

    pal = sns.color_palette()

    # Each PC pair gets its own 2D SVM, used for visualization only.
    for axis, (pc_x, pc_y) in zip(axes, pc_pairs):
        x_name = f"PC {pc_x + 1}"
        y_name = f"PC {pc_y + 1}"
        X_2d = np.column_stack([scores[:, pc_x], scores[:, pc_y]])

        svm_2d = SVC(kernel="rbf")
        svm_2d.fit(X_2d, y_svm)

        # Evaluate the decision function on a grid padded 10% beyond the data.
        x_min, x_max = X_2d[:, 0].min(), X_2d[:, 0].max()
        y_min, y_max = X_2d[:, 1].min(), X_2d[:, 1].max()
        x_pad = (x_max - x_min) * 0.1
        y_pad = (y_max - y_min) * 0.1
        xx, yy = np.meshgrid(
            np.linspace(x_min - x_pad, x_max + x_pad, 300),
            np.linspace(y_min - y_pad, y_max + y_pad, 300),
        )
        Z = svm_2d.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

        # Regions are drawn before the scatter so the points sit on top.
        # Z<0 -> class 0 (Lived, pal[2]); Z>0 -> class 1 (Died, pal[0])
        axis.contourf(
            xx, yy, Z, levels=[-1000, 0, 1000], alpha=0.15, colors=[pal[2], pal[0]]
        )
        axis.contour(
            xx, yy, Z, levels=[0], colors="k", linewidths=1.5, linestyles="--"
        )

        df_scores = pd.DataFrame(
            {
                x_name: scores[:, pc_x],
                y_name: scores[:, pc_y],
                "Outcome": score_labels.to_numpy(),
            }
        )
        sns.scatterplot(
            data=df_scores,
            x=x_name,
            y=y_name,
            hue="Outcome",
            palette=[pal[0], pal[2]],
            hue_order=[1, 0],
            ax=axis,
        )
        axis.set(
            xlabel=x_name,
            ylabel=y_name,
            title=f"C19-scores 2D (5-PC SVM CV AUC={cv_auc_5d:.2f})",
        )

    # Hide any axis that received no PC pair instead of leaving a blank panel.
    for spare_axis in axes[len(pc_pairs):]:
        spare_axis.set_visible(False)

Loading
Loading