diff --git a/setup.cfg b/setup.cfg index 7f58c58d..67bc4ce6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,8 @@ install_requires = scvi-tools>=0.15.3 torch>=1.9.0 pymc3>=3.8<3.10 + jax>=0.3,<=0.3.6 + jaxlib>=0.1.65,<=0.3.5 arviz==0.10.0 numpy pandas diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index 8d33f0de..3bf8210e 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -24,6 +24,8 @@ def test_cell2location(): sc_model.train(max_epochs=1, batch_size=1000) # export the estimated cell abundance (summary of the posterior distribution) dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10}) + # test exclusion of observed variables from posterior sampling + assert "data_target" not in dataset.uns["mod"]["post_sample_means"].keys() # test plot_QC sc_model.plot_QC() # test save/load @@ -47,6 +49,8 @@ def test_cell2location(): # export the estimated cell abundance (summary of the posterior distribution) # full data dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) + # test exclusion of observed variables from posterior sampling + assert "data_target" not in dataset.uns["mod"]["post_sample_means"].keys() ## minibatches of locations ## Cell2location.setup_anndata(dataset, batch_key="batch") st_model = Cell2location(dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200)