Commit 5cfbe46

added full model hessian and fisher_inv
1 parent 5bfaa12 commit 5cfbe46

5 files changed (+66, -46)

batchglm/models/nb_glm/base.py (7 additions, 8 deletions)

@@ -35,9 +35,8 @@
 ESTIMATOR_PARAMS.update({
     "loss": (),
     "gradient": ("features",),
-    "hessian_diagonal": ("features", "variables",),
-    "fisher_loc": ("design_loc_params", "features"),
-    "fisher_scale": ("design_scale_params", "features"),
+    "hessians": ("features", "delta_var0", "delta_var1"),
+    "fisher_inv": ("features", "delta_var0", "delta_var1"),
 })
 
 
@@ -393,7 +392,7 @@ def validate_data(self, **kwargs):
 
     def __init__(self, estim: AbstractEstimator):
         input_data = estim.input_data
-        params = estim.to_xarray(["a", "b", "loss", "gradient", "fisher_loc", "fisher_scale"], coords=input_data.data)
+        params = estim.to_xarray(["a", "b", "loss", "gradient", "hessians", "fisher_inv"], coords=input_data.data)
 
         XArrayModel.__init__(self, input_data, params)
 
@@ -410,9 +409,9 @@ def gradient(self):
         return self.params["loss"]
 
     @property
-    def fisher_loc(self):
-        return self.params["fisher_loc"]
+    def hessians(self):
+        return self.params["hessians"]
 
     @property
-    def fisher_scale(self):
-        return self.params["fisher_scale"]
+    def fisher_inv(self):
+        return self.params["fisher_inv"]
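
A note on the new dimension tuples: ESTIMATOR_PARAMS maps each estimator output to its named xarray dimensions, so both new entries declare one square matrix over the combined (a, b) parameter axis per feature. Two distinct names, delta_var0 and delta_var1, presumably label the row and column copies of that axis, since xarray works poorly with a repeated dimension name. A minimal sketch of the resulting layout (toy shapes, not batchglm internals):

import numpy as np
import xarray as xr

n_features, n_params = 5, 3  # toy sizes; n_params = design_loc_params + design_scale_params

hessians = xr.DataArray(
    np.zeros((n_features, n_params, n_params)),
    dims=("features", "delta_var0", "delta_var1"),
)
print(hessians.isel(features=0).shape)  # (3, 3): one full Hessian per feature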

batchglm/train/tf/base.py (25 additions, 2 deletions)

@@ -13,6 +13,25 @@
 from .train import StopAtLossHook, TimedRunHook
 
 
+# def model_param(f: callable, key: str, param_dict):
+#     """
+#     Special decorator for TFEstimator's model params.
+#
+#     :param f: the function to decorate
+#     :param key: the name of the data item to fetch
+#     :param param_dict: the dict where to add the function
+#     :return: decorated function without the "data" parameter
+#     """
+#
+#     def wrap_fn(self, *args, **kwargs):
+#         data = self._get_unsafe(key)
+#         return f(self, data, *args, **kwargs)
+#
+#     param_dict[key] = wrap_fn
+#
+#     return wrap_fn
+
 class TFEstimatorGraph(metaclass=abc.ABCMeta):
     graph: tf.Graph
     loss: tf.Tensor
 
@@ -88,9 +107,13 @@ class TrainingStrategy(Enum):
     session: tf.Session
     feed_dict: Dict[Union[Union[tf.Tensor, tf.Operation], Any], Any]
 
+    _param_decorators: Dict[str, callable]
+
     def __init__(self, tf_estimator_graph):
         self.model = tf_estimator_graph
         self.session = None
+
+        self._param_decorators = dict()
 
     def initialize(self):
         self.close_session()
 
@@ -135,11 +158,11 @@ def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]:
 
     @property
     def global_step(self):
-        return self.get("global_step")
+        return self._get_unsafe("global_step")
 
     @property
     def loss(self):
-        return self.get("loss")
+        return self._get_unsafe("loss")
 
     def _train_to_convergence(self,
                               train_op,
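
The commented-out model_param helper, together with the new _param_decorators dict, sketches a registry pattern: wrap an accessor so it fetches its backing value via _get_unsafe(key) and record the wrapper under that key. A hypothetical sketch of the idea, reshaped into decorator-factory form for illustration (not active batchglm API):

PARAMS = {}  # stand-in for TFEstimator._param_decorators

def model_param(key, param_dict=PARAMS):
    def decorator(f):
        def wrap_fn(self, *args, **kwargs):
            data = self._get_unsafe(key)  # raw fetch, bypassing get()'s checks
            return f(self, data, *args, **kwargs)
        param_dict[key] = wrap_fn
        return wrap_fn
    return decorator

class DemoEstimator:
    def _get_unsafe(self, key):
        return {"loss": 0.5}[key]  # toy stand-in for a session fetch

    @model_param("loss")
    def loss(self, data):
        return data

print(DemoEstimator().loss())  # 0.5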

batchglm/train/tf/nb_glm/estimator.py (29 additions, 32 deletions)

@@ -199,20 +199,24 @@ def feature_wise_hessians(X, design_loc, design_scale, a, b, size_factors=None)
 
     def hessian(data):  # data is tuple (X_t, a_t, b_t)
         X_t, a_t, b_t = data
-        X = tf.transpose(X_t)
-        a = tf.transpose(a_t)
-        b = tf.transpose(b_t)
+        X = tf.transpose(X_t)  # observations x features
+        a = tf.transpose(a_t)  # design_loc_params x features
+        b = tf.transpose(b_t)  # design_scale_params x features
 
-        model = BasicModelGraph(X, design_loc, design_scale, a, b, size_factors=size_factors)
+        # trick TensorFlow into also computing the mixed derivatives d^2 LL / (da db)
+        param_vec = tf.concat([a, b], axis=0, name="param_vec")
+        a_split, b_split = tf.split(param_vec, tf.TensorShape([a.shape[0], b.shape[0]]))
 
-        hess = tf.hessians(-model.log_likelihood, [a, b])
+        model = BasicModelGraph(X, design_loc, design_scale, a_split, b_split, size_factors=size_factors)
+
+        hess = tf.hessians(-model.log_likelihood, param_vec)
 
         return hess
 
     hessians = tf.map_fn(
         fn=hessian,
         elems=(X_t, a_t, b_t),
-        dtype=[tf.float32, tf.float32],  # hessians of [a, b]
+        dtype=[tf.float32],  # one full hessian over the concatenated [a, b]
         parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
     )
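
Background on the param_vec trick above: tf.hessians(ys, xs) returns one Hessian per tensor in xs and never the mixed second derivatives between two different xs, so the old tf.hessians(-ll, [a, b]) call yielded only the two diagonal blocks. Making the likelihood a function of the single concatenated param_vec lets one call return the full Hessian, a-b cross block included; it also means the result is a one-element list, which is why the next hunk unpacks hessians = hessians[0]. A standalone toy sketch of the trick (TF1-style API, made-up two-parameter "likelihood", not batchglm code):

import numpy as np
import tensorflow as tf

a = tf.constant([1.0, 2.0])  # stand-in for the location parameters
b = tf.constant([3.0])       # stand-in for the scale parameters

# concatenate, then split back, so the graph depends on one tensor
param_vec = tf.concat([a, b], axis=0)
a_split, b_split = tf.split(param_vec, [2, 1])

# toy scalar "log-likelihood" coupling a and b
ll = tf.reduce_sum(a_split * a_split) * tf.reduce_sum(b_split)

full_hessian = tf.hessians(-ll, param_vec)[0]  # tf.hessians returns a list

with tf.Session() as sess:
    print(sess.run(full_hessian))  # 3x3 matrix with nonzero a-b cross entries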

@@ -271,6 +275,7 @@ def hessian_red(prev, cur):
             reduce_fn=hessian_red,
             parallel_iterations=1,
         )
+        hessians = hessians[0]
 
         self.X = model.X
         self.design_loc = model.design_loc
@@ -395,17 +400,17 @@ def __init__(
             )
             full_data_loss = full_data_model.loss
 
-            with tf.name_scope("hessian_diagonal"):
-                hessian_diagonal = [
-                    tf.map_fn(
-                        # elems=tf.transpose(hess, perm=[2, 0, 1]),
-                        elems=hess,
-                        fn=tf.diag_part,
-                        parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
-                    )
-                    for hess in full_data_model.hessians
-                ]
-                fisher_a, fisher_b = hessian_diagonal
+            # with tf.name_scope("hessian_diagonal"):
+            #     hessian_diagonal = [
+            #         tf.map_fn(
+            #             # elems=tf.transpose(hess, perm=[2, 0, 1]),
+            #             elems=hess,
+            #             fn=tf.diag_part,
+            #             parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
+            #         )
+            #         for hess in full_data_model.hessians
+            #     ]
+            #     fisher_a, fisher_b = hessian_diagonal
 
             mu = full_data_model.mu
             r = full_data_model.r
@@ -552,12 +557,8 @@ def __init__(
 
             # we are minimizing the negative LL instead of maximizing the LL
             # => invert hessians
-            self.hessian_diagonal = - tf.concat([
-                fisher_a,
-                fisher_b,
-            ], axis=-1)
-            self.fisher_loc = tf.transpose(fisher_a, name="fisher_loc")
-            self.fisher_scale = tf.transpose(fisher_b, name="fisher_scale")
+            self.hessians = - full_data_model.hessians
+            self.fisher_inv = tf.matrix_inverse(full_data_model.hessians)
 
         with tf.name_scope('summaries'):
             tf.summary.histogram('a', model_vars.a)
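
For reference, full_data_model.hessians is the Hessian of the negative log-likelihood (see hess = tf.hessians(-model.log_likelihood, ...) above), i.e. the observed Fisher information. Negating it exposes the log-likelihood Hessian as hessians, while inverting it gives the standard asymptotic covariance estimate of the fitted parameters. A toy sketch of the two assignments (stand-in tensors, not batchglm code):

import numpy as np
import tensorflow as tf

# stand-in for full_data_model.hessians: one negative-LL Hessian per feature
neg_ll_hessians = tf.constant(
    np.stack([np.eye(3) * k for k in (1.0, 2.0)]), dtype=tf.float32
)

hessians = -neg_ll_hessians                      # Hessian of the LL itself
fisher_inv = tf.matrix_inverse(neg_ll_hessians)  # approximate covariance of the MLE

with tf.Session() as sess:
    print(sess.run(fisher_inv)[1])  # inverse of eye(3) * 2.0 for the second toy feature
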
@@ -715,7 +716,7 @@ def __init__(self,
         X = input_data.X.assign_coords(group=(("observations",), inverse_idx))
         mean = X.groupby("group").mean(dim="observations")
 
-        [X[inverse_idx==i].mean(dim="observations").values for i in np.unique(inv_design)]
+        [X[inverse_idx == i].mean(dim="observations").values for i in np.unique(inv_design)]
         a = np.log(mean)
         # a = a * np.eye(np.size(a))
         a_prime = np.matmul(inv_design, a)
@@ -978,16 +979,12 @@ def gradient(self):
         return self.to_xarray("full_gradient", coords=self.input_data.data.coords)
 
     @property
-    def hessian_diagonal(self):
-        return self.to_xarray("hessian_diagonal", coords=self.input_data.data.coords)
-
-    @property
-    def fisher_loc(self):
-        return self.to_xarray("fisher_loc", coords=self.input_data.data.coords)
+    def hessians(self):
+        return self.to_xarray("hessians", coords=self.input_data.data.coords)
 
     @property
-    def fisher_scale(self):
-        return self.to_xarray("fisher_scale", coords=self.input_data.data.coords)
+    def fisher_inv(self):
+        return self.to_xarray("fisher_inv", coords=self.input_data.data.coords)
 
     def finalize(self):
         store = XArrayEstimatorStore(self)

batchglm/unit_test/test_nb.py (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ def test_default_fit(self):
         estimator = estimator.finalize()
         print(estimator.mu.values)
         print(estimator.gradient.values)
-        print(estimator.hessian_diagonal.values)
+        print(estimator.hessians.values)
         print(estimator.probs().values)
         print(estimator.log_probs().values)

batchglm/unit_test/test_nb_glm.py (4 additions, 3 deletions)

@@ -30,7 +30,7 @@ def estimate(input_data: InputData, working_dir: str):
     )
     input_data.save(os.path.join(working_dir, "input_data.h5"))
 
-    estimator.train_sequence()
+    estimator.train_sequence(training_strategy="QUICK")
 
     return estimator
 
@@ -69,7 +69,7 @@ def test_default_fit(self):
         estimator = estimator.finalize()
         print(estimator.mu.values)
         print(estimator.gradient.values)
-        print(estimator.hessian_diagonal.values)
+        print(estimator.hessians.values)
         print(estimator.probs().values)
         print(estimator.log_probs().values)
 
@@ -116,7 +116,8 @@ def test_nonconfounded_fit(self):
         estimator = estimator.finalize()
         print(estimator.mu.values)
         print(estimator.gradient.values)
-        print(estimator.hessian_diagonal.values)
+        print(estimator.hessians.values)
+        print(estimator.fisher_inv.values)
         print(estimator.probs().values)
         print(estimator.log_probs().values)
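
Usage note: once an estimator is finalized as in these tests, the new fisher_inv output can be turned into per-parameter standard errors. A hedged sketch, assuming the (features, delta_var0, delta_var1) layout declared in ESTIMATOR_PARAMS and a finalized estimator as above:

import numpy as np

fisher_inv = estimator.fisher_inv.values  # shape: (features, p, p)
std_errors = np.sqrt(np.diagonal(fisher_inv, axis1=-2, axis2=-1))
print(std_errors.shape)  # (features, p): one standard error per parameter and feature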
