Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
173 commits
Select commit Hold shift + click to select a range
8365331
add_velocity function
PalkaPuri Aug 7, 2024
e2ee500
velocity predictor WIP
PalkaPuri Aug 8, 2024
6a19ae4
wip
PalkaPuri Aug 8, 2024
b5dde8a
format/lint
PalkaPuri Aug 8, 2024
a923457
specify forager column as int
PalkaPuri Aug 12, 2024
8cc4111
add warning
PalkaPuri Aug 12, 2024
892da93
format/lint
PalkaPuri Aug 12, 2024
f675d23
rename f,t for clarity
PalkaPuri Aug 13, 2024
1ee7caa
rerun notebook
PalkaPuri Aug 13, 2024
fcd34b2
add docstrings
PalkaPuri Aug 13, 2024
dc950c2
format/lint
PalkaPuri Aug 13, 2024
205235f
use inbuilt function for gaussian pdf
PalkaPuri Aug 13, 2024
b476a17
merge changes from pp-collab2-compute-velocity
PalkaPuri Aug 13, 2024
045e7d0
updated visualization
PalkaPuri Aug 13, 2024
05b3d90
add handling of nan values in visualization
PalkaPuri Aug 15, 2024
5774858
add _generate_pairwise_copying predictor
PalkaPuri Aug 15, 2024
41a7787
format/lint
PalkaPuri Aug 15, 2024
ab6ee7e
type hints
PalkaPuri Aug 15, 2024
47679f3
add pairwise predictor and animation function
PalkaPuri Aug 21, 2024
fac3f10
change velocity to backward looking
PalkaPuri Aug 21, 2024
ddf4572
Merge branch 'pp-collab2-compute-velocity' of https://github.com/Basi…
PalkaPuri Aug 21, 2024
4b242d8
docstrings and type hints
PalkaPuri Aug 21, 2024
82223c0
generate function and typehints
PalkaPuri Aug 21, 2024
11f5dce
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
PalkaPuri Aug 21, 2024
0dec71b
fixed naming, nan handling
PalkaPuri Aug 21, 2024
ea399b2
refactoring
PalkaPuri Aug 22, 2024
9a2ec97
vicsek wip
PalkaPuri Aug 23, 2024
033e565
update normalization of velocity predictors
PalkaPuri Aug 26, 2024
9c6854b
distance_to_next_move WIP
PalkaPuri Aug 26, 2024
8bfa2d0
Merge branch 'pp-collab2-vicsek-predictor' of https://github.com/Basi…
PalkaPuri Aug 26, 2024
e61915f
added alternates for distance_to_next_pos
PalkaPuri Aug 26, 2024
e79917e
lint/format
PalkaPuri Aug 26, 2024
ab7df53
updated nb example
PalkaPuri Aug 26, 2024
21e0361
update scores wip
PalkaPuri Aug 27, 2024
fca3235
merge rafal's changes
PalkaPuri Aug 27, 2024
d9cadcf
absolute paths for imports in modules
PalkaPuri Aug 27, 2024
3999139
update module to allow chnaging n
PalkaPuri Aug 27, 2024
540b11b
renaming
PalkaPuri Aug 27, 2024
a7b0495
format
PalkaPuri Aug 27, 2024
d0018bd
trace code in place
rfl-urbaniak Aug 27, 2024
319aa97
rename n
PalkaPuri Aug 28, 2024
cf3d543
merge staging-collab-2
PalkaPuri Aug 28, 2024
e47d161
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
PalkaPuri Aug 29, 2024
92a1957
add derive_predictors
PalkaPuri Aug 29, 2024
fd2a3f3
test inference wip
PalkaPuri Aug 29, 2024
d573538
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
PalkaPuri Aug 29, 2024
e2d3615
derive WIP
PalkaPuri Aug 29, 2024
f897604
test inference pipeline
PalkaPuri Aug 29, 2024
5230e77
format/lint
PalkaPuri Aug 29, 2024
e8dba19
format again
PalkaPuri Aug 29, 2024
a0f024c
one more format
PalkaPuri Aug 29, 2024
e07b3ff
lint
PalkaPuri Aug 29, 2024
5faaa13
format
PalkaPuri Aug 29, 2024
b01c8cf
adding doc to local_windows WIP
PalkaPuri Aug 29, 2024
e7bfeaf
small typos in docstrings
rfl-urbaniak Aug 30, 2024
4476c6b
added predictor time to logging
rfl-urbaniak Aug 30, 2024
bcbd8d6
added time to score logging
rfl-urbaniak Aug 30, 2024
72d28f4
time to logging, a few typos
rfl-urbaniak Aug 30, 2024
8d6a2ce
format lint
rfl-urbaniak Aug 30, 2024
5c70d1f
refactored proximity
rfl-urbaniak Aug 30, 2024
eb406a7
removed old proximity code
rfl-urbaniak Aug 30, 2024
bfbf187
format lint
rfl-urbaniak Aug 30, 2024
065518c
refactored trace
rfl-urbaniak Aug 30, 2024
20b89c1
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
rfl-urbaniak Aug 30, 2024
bdc5bff
added generate food to init
rfl-urbaniak Aug 30, 2024
0de9af7
format, lint
rfl-urbaniak Aug 30, 2024
1f0e506
update parameter names in docstring
PalkaPuri Sep 3, 2024
a7c55cc
explanation of score_kwargs
PalkaPuri Sep 3, 2024
b14029d
update next_step_score notebook
PalkaPuri Sep 3, 2024
2b8126f
update warning and add docstring to dataObject
PalkaPuri Sep 3, 2024
f00ce02
removing reusing velocity warning
PalkaPuri Sep 3, 2024
4d11e85
removed individual drop warning, add to derive_predictors_and_scores
PalkaPuri Sep 3, 2024
66688b8
lint/format
PalkaPuri Sep 3, 2024
c3f875e
fix bug in test
PalkaPuri Sep 3, 2024
d7fe531
format
PalkaPuri Sep 3, 2024
21083e8
Merge branch 'pp-collab2-derivepredictors' into pp-collab2-housekeeping
PalkaPuri Sep 3, 2024
2b8ecb2
local_windows update constraint implementation, add docstring
PalkaPuri Sep 3, 2024
5df34c7
rename variables
PalkaPuri Sep 3, 2024
62f0cb2
filter_by_distance update constraint implementation
PalkaPuri Sep 3, 2024
728a182
remove fps argument from subsampling func
PalkaPuri Sep 3, 2024
c7bc9bf
velocity- rename predictorID, update constraint implementation, docst…
PalkaPuri Sep 3, 2024
b583275
ensure all test notebooks run
PalkaPuri Sep 3, 2024
a621300
format/lint WIP
PalkaPuri Sep 3, 2024
e33c5f8
Merge branch 'ru-new-trace' into pp-collab2-housekeeping
PalkaPuri Sep 3, 2024
1b28cc3
ensure rafals nbs run
PalkaPuri Sep 3, 2024
c4ddd34
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
PalkaPuri Sep 4, 2024
2d9829e
type hints for kwargs in constraint func
PalkaPuri Sep 4, 2024
2ea9f12
add collab2 notebooks to automatic testing
PalkaPuri Sep 4, 2024
953558e
small fixes to local_windows.py docstrings
rfl-urbaniak Sep 4, 2024
f75ec48
interpunction in filtering.py
rfl-urbaniak Sep 4, 2024
d9249c1
proximity notebook now works
rfl-urbaniak Sep 5, 2024
575db78
wip
rfl-urbaniak Sep 5, 2024
2bd0546
fixed trace and derivation notebooks
rfl-urbaniak Sep 5, 2024
5913050
resolved type hinting problems
rfl-urbaniak Sep 5, 2024
f3f7e4f
fix some bugs
PalkaPuri Sep 5, 2024
bba492b
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
PalkaPuri Sep 5, 2024
dd1816c
notebook issue fix
PalkaPuri Sep 5, 2024
b4edb96
fix type hint in proximity
PalkaPuri Sep 5, 2024
ca693a8
Merge branch 'pp-collab2-housekeeping' of https://github.com/BasisRes…
rfl-urbaniak Sep 5, 2024
b88937e
grid_constraint_params updated to dict
PalkaPuri Sep 9, 2024
35f9131
interaction_constraint_params updated to dict
PalkaPuri Sep 9, 2024
1bcf668
make sure all notebooks pass + format/lint
PalkaPuri Sep 9, 2024
5572734
lint
rfl-urbaniak Sep 9, 2024
59e6215
lint
rfl-urbaniak Sep 9, 2024
f9b92da
lint
rfl-urbaniak Sep 9, 2024
bdd9965
lint
rfl-urbaniak Sep 9, 2024
1e1a548
lint
rfl-urbaniak Sep 9, 2024
5467b05
lint
rfl-urbaniak Sep 9, 2024
cc63147
lint
rfl-urbaniak Sep 9, 2024
c2f4c82
suspended old velocity test
rfl-urbaniak Sep 9, 2024
954d9e5
rhf old update
rfl-urbaniak Sep 17, 2024
032b3f4
Merge branch 'ru-random-hungry-2' of https://github.com/BasisResearch…
rfl-urbaniak Sep 17, 2024
944b20a
re-ran hungry and random foragers
rfl-urbaniak Sep 17, 2024
5780fef
following foragers replicated
rfl-urbaniak Sep 17, 2024
4ee8c47
format, lint
rfl-urbaniak Sep 18, 2024
1d31c6b
change to sublinear
rfl-urbaniak Sep 18, 2024
f0293fb
fixing velocity_predicts WIP
rfl-urbaniak Sep 18, 2024
e3d75c6
updated trace_predictor
rfl-urbaniak Sep 18, 2024
c1c0b8b
suspend animations in velocity predictors
rfl-urbaniak Sep 18, 2024
5205da4
added access predictor
rfl-urbaniak Sep 19, 2024
ec21ab4
updated random with access
rfl-urbaniak Sep 19, 2024
233c693
updateg hungry foragers
rfl-urbaniak Sep 19, 2024
bab23a2
updated followers with access
rfl-urbaniak Sep 19, 2024
b349262
format, lint
rfl-urbaniak Sep 19, 2024
2e3321c
communication WIP
rfl-urbaniak Sep 19, 2024
9b6e058
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
rfl-urbaniak Sep 19, 2024
07c231f
wip
rfl-urbaniak Sep 19, 2024
52b339b
fixed lint
rfl-urbaniak Sep 19, 2024
920b509
WIP
rfl-urbaniak Sep 19, 2024
22104de
fixed imports in random foragers
rfl-urbaniak Sep 19, 2024
acc6d96
WIP
rfl-urbaniak Sep 19, 2024
f20a472
fixing init
rfl-urbaniak Sep 19, 2024
e313810
communication predictor notebook functional
rfl-urbaniak Sep 19, 2024
605ed17
format, lint
rfl-urbaniak Sep 19, 2024
f3a6aa1
restored init fix
rfl-urbaniak Sep 19, 2024
8f61e4e
bump pyro to 1.9.1
rfl-urbaniak Sep 20, 2024
e0dfbaf
revert (chirho), remove deterministic from rendering
rfl-urbaniak Sep 20, 2024
77752a3
Merge branch 'ru-random-hungry-2' of https://github.com/BasisResearch…
rfl-urbaniak Sep 20, 2024
2548754
pkl to gitignore
rfl-urbaniak Sep 20, 2024
bcc7617
locust wip
rfl-urbaniak Sep 24, 2024
1ba8214
locust WIP
rfl-urbaniak Sep 24, 2024
05b2ce3
locust wip
rfl-urbaniak Sep 27, 2024
52d0f14
Some polishing of the random/hungry/followers notebooks and fixing st…
dimkab Oct 2, 2024
9ee6bad
Merge branch 'ru-random-hungry-2' of https://github.com/BasisResearch…
rfl-urbaniak Oct 2, 2024
f8cb283
format
rfl-urbaniak Oct 2, 2024
002144e
pin cleanup package versions
rfl-urbaniak Oct 2, 2024
ce347ba
upgrade and pin nbqa
rfl-urbaniak Oct 2, 2024
a52eb2b
suspend outdated velocity test
rfl-urbaniak Oct 2, 2024
4de6bab
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
rfl-urbaniak Oct 3, 2024
b683cef
missing import for generate_communication_predictor in __init__.py
dimkab Oct 3, 2024
cd8f8c5
decoupled communicator simulations from collab 1
rfl-urbaniak Oct 5, 2024
ff4a04d
format lint
rfl-urbaniak Oct 5, 2024
686e384
ensure starts at 0 with communicators inference, re-run
rfl-urbaniak Oct 5, 2024
946c9f3
format, lint
rfl-urbaniak Oct 5, 2024
eddce38
Merge branch 'ru-communication' of https://github.com/BasisResearch/c…
rfl-urbaniak Oct 5, 2024
219fcf0
change framing to approximately the one in paper
rfl-urbaniak Oct 5, 2024
e664c6f
re-run locust, format, lint
rfl-urbaniak Oct 5, 2024
3654da6
cp birds predictors derived
rfl-urbaniak Oct 8, 2024
6eeca70
wip
rfl-urbaniak Oct 8, 2024
c757691
update notebook wip
rfl-urbaniak Oct 10, 2024
34aecc6
wip
rfl-urbaniak Oct 10, 2024
d73506f
cp wip
rfl-urbaniak Oct 10, 2024
fa56d1f
cp wip
rfl-urbaniak Oct 12, 2024
feb7393
bayesian optimization for ducks
rfl-urbaniak Oct 13, 2024
7e773dd
ducks optimization
rfl-urbaniak Oct 15, 2024
09f3571
cp optimization wip
rfl-urbaniak Oct 16, 2024
9f3c30b
before central parks cleanup
rfl-urbaniak Oct 17, 2024
68b9512
reinstate proximity
rfl-urbaniak Oct 17, 2024
0c1d7ab
cp bnotebook revised
rfl-urbaniak Oct 17, 2024
bc193e9
format lint
rfl-urbaniak Oct 17, 2024
68bdaec
Merge branch 'staging-collab-2' of https://github.com/BasisResearch/c…
rfl-urbaniak Oct 17, 2024
bc0d5ca
added uninformed prior, format, lint
rfl-urbaniak Oct 17, 2024
6b5c22b
add bayesian opt to setup to pass CI
rfl-urbaniak Oct 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ tmp_files/
# folder to temporarily store files
data/user_data/*


**/xmemo/

data/communicators/communicators_strong/*
Expand All @@ -171,3 +170,7 @@ data/foraging/locust/ds/priors_sam30_15EQ20191202_s0_e10.pkl
docs/foraging/random-hungry-followers/sim_data/hungry_sim.dill
docs/foraging/random-hungry-followers/sim_data/hungry_foragers_samples.dill
docs/foraging/random-hungry-followers/sim_data/hungry_foragers_samples.dill
docs/foraging/central-park-birds/ducks_proximity_single_optimizer.pkl
docs/foraging/central-park-birds/ducks_proximity_optimizer.pkl
docs/foraging/central-park-birds/sparrows_proximity_single_optimizer.pkl
docs/foraging/central-park-birds/ducks_access_optimizer.pkl
10 changes: 9 additions & 1 deletion collab2/foraging/toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from .filtering import constraint_filter_nearest, filter_by_distance # noqa: F401
from .food import generate_food_predictor # noqa: F401
from .inference import run_svi_inference # type: ignore # noqa: F401
from .inference import get_samples, prep_data_for_inference, summary # noqa: F401
from .inference import ( # noqa: F401
get_samples,
prep_data_for_inference,
prep_DF_data_for_inference,
summary,
)
from .local_windows import ( # noqa: F401
_generate_local_windows,
_get_grid,
Expand Down Expand Up @@ -69,4 +74,7 @@
)
from .visualization import animate_predictors, plot_predictor # noqa: F401

# from .waic import compute_waic # noqa: F401


# from .trace import rewards_to_trace, rewards_trace # noqa: F401
5 changes: 3 additions & 2 deletions collab2/foraging/toolkit/animate_foragers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
warnings.simplefilter(action="ignore", category=FutureWarning)


def plot_trajectories(df, title):
def plot_trajectories(df, title, legend=True):
unique_foragers = df["forager"].unique()
plt.figure()

Expand Down Expand Up @@ -41,7 +41,8 @@ def plot_trajectories(df, title):
plt.axis("equal")
plt.gca().invert_yaxis()
plt.axis("off")
plt.legend()
if legend:
plt.legend()
plt.suptitle(f"Trajectories: {title}", fontsize=16)
return plt

Expand Down
58 changes: 53 additions & 5 deletions collab2/foraging/toolkit/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,44 @@ def prep_data_for_inference(
return predictor_tensors, outcome_tensors


def prep_DF_data_for_inference(
DF, predictors: List[str], outcome_vars: str, subsample_rate: float = 1.0
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:

if isinstance(outcome_vars, str):
outcome_list = [outcome_vars]
else:
outcome_list = outcome_vars

df = DF[predictors + outcome_list].copy()

# assert no NaNs in df
assert df.notna().all().all(), "Dataframe contains NaN values"

# Apply subsampling
if subsample_rate < 1.0:
df = df.sample(frac=subsample_rate).reset_index(drop=True)

# Apply subsampling
if subsample_rate < 1.0:
df = df.sample(frac=subsample_rate).reset_index(drop=True)

predictor_tensors = {
key: torch.tensor(df[key].values, dtype=torch.float32) for key in predictors
}
outcome_tensors = {
key: torch.tensor(df[key].values, dtype=torch.float32) for key in outcome_list
}

# print size
print("Sample size:", len(df))

# print size
print("Sample size:", len(df))

return predictor_tensors, outcome_tensors


def summary(samples, sites):
site_stats = {}
for site_name, values in samples.items():
Expand Down Expand Up @@ -79,7 +117,7 @@ def run_svi_inference(
loss.backward()
losses.append(loss.item())
adam.step()
if (step % 50 == 0) or (step == 1) & verbose:
if (step % 200 == 0) or (step == 1) & verbose:
print("[iteration %04d] loss: %.4f" % (step, loss))

if plot:
Expand All @@ -97,13 +135,20 @@ def get_samples(
outcome,
num_svi_iters,
num_samples,
plot=True,
verbose=True,
):

logging.info(f"Starting SVI inference with {num_svi_iters} iterations.")
start_time = time.time()
pyro.clear_param_store()
guide = run_svi_inference(
model, n_steps=num_svi_iters, predictors=predictors, outcome=outcome
model,
n_steps=num_svi_iters,
predictors=predictors,
outcome=outcome,
plot=plot,
verbose=verbose,
)
end_time = time.time()
elapsed_time = end_time - start_time
Expand All @@ -118,18 +163,21 @@ def get_samples(
for k, v in predictive(predictors, outcome).items()
if k != "obs"
}
print(samples.keys())

sites = [
key
for key in samples.keys()
if (key.startswith("weight") and not key.endswith("sigma"))
]
print(sites)

print("Coefficient marginals:")
for site, values in summary(samples, sites).items():
print("Site: {}".format(site))
print(values, "\n")

return {"samples": samples, "guide": guide, "predictive": predictive}
return {
"samples": samples,
"guide": guide,
"predictive": predictive,
"summaries": summary(samples, sites),
}
10 changes: 7 additions & 3 deletions collab2/foraging/toolkit/next_step_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ def _generate_nextStep_score(
score[f][t]["distance_to_next_step"] = np.nan
score[f][t][score_name] = np.nan

# save nans for last frame
score[f][num_frames - 1]["distance_to_next_step"] = np.nan
score[f][num_frames - 1][score_name] = np.nan
# save nans for last fram
if (
score[f][num_frames - 1] is not None
): # conditioning needed as last frame is not always present
# as in cp birds data
score[f][num_frames - 1]["distance_to_next_step"] = np.nan
score[f][num_frames - 1][score_name] = np.nan

return score

Expand Down
5 changes: 4 additions & 1 deletion collab2/foraging/toolkit/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def _piecewise_proximity_function(

result = np.where(
cond1,
np.sin(np.pi / (2 * repulsion_radius) * (distance + 3 * repulsion_radius)),
np.sin(
np.pi / ((2 * repulsion_radius) * (distance + 3 * repulsion_radius))
+ 0.0001
), # division by zero errors
np.where(
cond2,
np.sin(
Expand Down
17 changes: 17 additions & 0 deletions collab2/foraging/toolkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def __init__(
self.derived_quantities: dict[str, List[List[pd.DataFrame]]] = {}
self.derivedDF: pd.DataFrame

def calculate_step_distances(self):
step_distances = []
for f in range(len(self.foragers)):
data = self.foragers[f].dropna()
unique_t = data["time"].unique()
for t in unique_t:
if (t - 1) in unique_t:
step_ago = data[data["time"] == t - 1]
xdiff = data[data["time"] == t]["x"].values - step_ago["x"].values
ydiff = data[data["time"] == t]["y"].values - step_ago["y"].values
step_distance = np.sqrt(xdiff**2 + ydiff**2)
step_distances.append(step_distance)

return np.concatenate(step_distances)

def calculate_step_size_max(self):
step_maxes = []

Expand Down Expand Up @@ -203,6 +218,8 @@ def distances_and_peaks(distances, bins=40, x_min=None, x_max=None):
color="red",
)

plt.show()


# remove rewards eaten by foragers in proximity
def update_rewards(sim, rewards, foragers, start=1, end=None):
Expand Down
2 changes: 2 additions & 0 deletions collab2/foraging/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def plot_predictor(
for c in range(len(time) % ncols, ncols):
fig.delaxes(axes[c])

fig.suptitle(f"Predictor: {predictor_name}")
fig.tight_layout(pad=2)
fig.show()


def animate_predictors(
Expand Down
69 changes: 69 additions & 0 deletions collab2/foraging/toolkit/waic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Callable, Dict, Optional

import pyro
import torch
from pyro.infer.enum import get_importance_trace


def compute_waic(
model: Callable[..., Any],
guide: Callable[..., Any],
num_particles: int,
max_plate_nesting: int,
sites: Optional[list[str]] = None,
*args: Any,
**kwargs: Any
) -> Dict[str, Any]:

def vectorize(fn: Callable[..., Any]) -> Callable[..., Any]:
def _fn(*args: Any, **kwargs: Any) -> Any:
with pyro.plate(
"num_particles_vectorized", num_particles, dim=-max_plate_nesting
):
return fn(*args, **kwargs)

return _fn

model_trace, _ = get_importance_trace(
"flat", max_plate_nesting, vectorize(model), vectorize(guide), args, kwargs
)

def site_filter_is_observed(site_name: str) -> bool:
return model_trace.nodes[site_name]["is_observed"]

def site_filter_in_sites(site_name: str) -> bool:
return sites is not None and site_name in sites

if sites is None:
site_filter = site_filter_is_observed
else:
site_filter = site_filter_in_sites

observed_nodes = {
name: node for name, node in model_trace.nodes.items() if site_filter(name)
}

log_p_post = {
key: observed_nodes[key]["log_prob"].mean(dim=0) # sum(axis = 0)/num_particles
for key in observed_nodes.keys()
}

lppd = torch.stack([log_p_post[key] for key in log_p_post.keys()]).sum()

var_log_p_post = {
key: (observed_nodes[key]["log_prob"]).var(axis=0)
for key in observed_nodes.keys()
}

p_waic = torch.stack([var_log_p_post[key] for key in var_log_p_post.keys()]).sum()

waic = -2 * (lppd - p_waic)

return {
"waic": waic,
"nodes": observed_nodes,
"log_p_post": log_p_post,
"var_log_p_post": var_log_p_post,
"lppd": lppd,
"p_waic": p_waic,
}
Loading