Skip to content

Commit 5bfaa12

Browse files
committed
fix param_nonzero_b in nb_glm Estimator
1 parent 0cf8fcf commit 5bfaa12

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,12 +490,12 @@ def __init__(
490490
name="a"
491491
)
492492

493-
param_nonzero_b = tf.broadcast_to(feature_isnonzero, [num_design_loc_params, num_features])
493+
param_nonzero_b = tf.broadcast_to(feature_isnonzero, [num_design_scale_params, num_features])
494494
alt_b = tf.concat([
495495
# intercept
496496
tf.broadcast_to(bounds_max["b"], [1, num_features]),
497497
# slope
498-
tf.zeros(shape=[num_design_scale_params - 1, num_features], dtype=model_vars.a.dtype),
498+
tf.zeros(shape=[num_design_scale_params - 1, num_features], dtype=model_vars.b.dtype),
499499
], axis=0, name="alt_b")
500500
b = tf.where(
501501
param_nonzero_b,

0 commit comments

Comments
 (0)