Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cell2location/models/_cell2location_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,15 @@ def __init__(
self.factor_names_ = cell_state_df.columns.values

if not detection_mean_per_sample:
# compute expected change in sensitivity (m_g in V1 or y_s in V2)
# compute expected change in sensitivity (m_g in V1 and y_s in V2)
sc_total = cell_state_df.sum(0).mean()
sp_total = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(1).mean()
sp_total = self.adata_manager.get_from_registry(_CONSTANTS.X_KEY).sum(1)
batch = self.adata_manager.get_from_registry(_CONSTANTS.BATCH_KEY).flatten()
sp_total = np.array([sp_total[batch == b].mean() for b in range(self.summary_stats["n_batch"])])
self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total
if (self.detection_mean_.max() > 1.0) and (model_kwargs.get("use_detection_probability", False) is True):
self.detection_mean_ = self.detection_mean_ / (self.detection_mean_.max() + 0.000001)
self.detection_mean_ = self.detection_mean_.mean()
self.detection_mean_ = self.detection_mean_ * detection_mean_correction
model_kwargs["detection_mean"] = self.detection_mean_
else:
Expand Down
60 changes: 41 additions & 19 deletions cell2location/models/_cell2location_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
n_groups: int = 50,
detection_mean=1 / 2,
detection_alpha=200.0,
use_detection_probability: bool = False,
m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0},
N_cells_per_location=8.0,
A_factors_per_location=7.0,
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
detection_hyp_prior["mean"] = detection_mean
detection_hyp_prior["alpha"] = detection_alpha
self.detection_hyp_prior = detection_hyp_prior
self.use_detection_probability = use_detection_probability

if (init_vals is not None) & (type(init_vals) is dict):
self.np_init_vals = init_vals
Expand Down Expand Up @@ -325,27 +327,47 @@ def forward(self, x_data, idx, batch_index):
) # (self.n_obs, self.n_factors)

# =====================Location-specific detection efficiency ======================= #
# y_s with hierarchical mean prior
detection_mean_y_e = pyro.sample(
"detection_mean_y_e",
dist.Gamma(
self.ones * self.detection_mean_hyp_prior_alpha,
self.ones * self.detection_mean_hyp_prior_beta,
if not self.use_detection_probability:
# y_s with hierarchical mean prior
detection_mean_y_e = pyro.sample(
"detection_mean_y_e",
dist.Gamma(
self.ones * self.detection_mean_hyp_prior_alpha,
self.ones * self.detection_mean_hyp_prior_beta,
)
.expand([self.n_batch, 1])
.to_event(2),
)
detection_hyp_prior_alpha = pyro.deterministic(
"detection_hyp_prior_alpha",
self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
)
.expand([self.n_batch, 1])
.to_event(2),
)
detection_hyp_prior_alpha = pyro.deterministic(
"detection_hyp_prior_alpha",
self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
)

beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e)
with obs_plate:
detection_y_s = pyro.sample(
"detection_y_s",
dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta),
) # (self.n_obs, 1)
beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e)
with obs_plate:
detection_y_s = pyro.sample(
"detection_y_s",
dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta),
) # (self.n_obs, 1)
else:
# y_s with hierarchical mean prior
detection_mean_y_e = pyro.sample(
"detection_mean_y_e",
dist.Beta(
self.ones * self.detection_mean_hyp_prior_alpha,
self.ones * self.detection_mean_hyp_prior_beta,
)
.expand([self.n_batch, 1])
.to_event(2),
)

alpha = (obs2sample @ detection_mean_y_e) * self.ones * self.detection_hyp_prior_alpha
beta = (obs2sample @ (self.ones - detection_mean_y_e)) * self.ones * self.detection_hyp_prior_alpha
with obs_plate:
detection_y_s = pyro.sample(
"detection_y_s",
dist.Beta(alpha, beta),
) # (self.n_obs, 1)

# =====================Gene-specific additive component ======================= #
# per gene molecule contribution that cannot be explained by
Expand Down
16 changes: 16 additions & 0 deletions tests/test_cell2location.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,19 @@ def test_cell2location():
# export the estimated cell abundance (summary of the posterior distribution)
# full data
st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})

### test new cell2location models ###
## detection probability rather than detection efficiency ##
st_model = Cell2location(
dataset,
cell_state_df=inf_aver,
N_cells_per_location=30,
detection_alpha=200,
use_detection_probability=True,
detection_hyp_prior={"mean_alpha": 100.0},
)
# test full data training
st_model.train(max_epochs=1)
# export the estimated cell abundance (summary of the posterior distribution)
# full data
dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})