@@ -199,20 +199,24 @@ def feature_wise_hessians(X, design_loc, design_scale, a, b, size_factors=None)
199199
200200 def hessian (data ): # data is the tuple (X_t, a_t, b_t) handed over by tf.map_fn, one slice of each per call
201201 X_t , a_t , b_t = data
202- X = tf .transpose (X_t )
203- a = tf .transpose (a_t )
204- b = tf .transpose (b_t )
202+ X = tf .transpose (X_t ) # observations x features
203+ a = tf .transpose (a_t ) # design_loc_params x features
204+ b = tf .transpose (b_t ) # design_scale_params x features
205205
206- model = BasicModelGraph (X , design_loc , design_scale , a , b , size_factors = size_factors )
206+ # trick: differentiate w.r.t. one concatenated [a, b] vector so tf.hessians also yields the mixed second derivatives d^2(-LL)/(da db)
207+ param_vec = tf .concat ([a , b ], axis = 0 , name = "param_vec" )
208+ a_split , b_split = tf .split (param_vec , tf .TensorShape ([a .shape [0 ], b .shape [0 ]]))
207209
208- hess = tf .hessians (- model .log_likelihood , [a , b ])
210+ model = BasicModelGraph (X , design_loc , design_scale , a_split , b_split , size_factors = size_factors )
211+
212+ hess = tf .hessians (- model .log_likelihood , param_vec ) # Hessian of the negative log-likelihood; tf.hessians returns a list (one entry here)
209213
210214 return hess
211215
212216 hessians = tf .map_fn (
213217 fn = hessian ,
214218 elems = (X_t , a_t , b_t ),
215- dtype = [tf .float32 , tf . float32 ], # hessians of [a, b]
219+ dtype = [tf .float32 ], # one Hessian w.r.t. the concatenated [a, b] parameter vector
216220 parallel_iterations = pkg_constants .TF_LOOP_PARALLEL_ITERATIONS
217221 )
218222
@@ -271,6 +275,7 @@ def hessian_red(prev, cur):
271275 reduce_fn = hessian_red ,
272276 parallel_iterations = 1 ,
273277 )
278+ hessians = hessians [0 ]
274279
275280 self .X = model .X
276281 self .design_loc = model .design_loc
@@ -395,17 +400,17 @@ def __init__(
395400 )
396401 full_data_loss = full_data_model .loss
397402
398- with tf .name_scope ("hessian_diagonal" ):
399- hessian_diagonal = [
400- tf .map_fn (
401- # elems=tf.transpose(hess, perm=[2, 0, 1]),
402- elems = hess ,
403- fn = tf .diag_part ,
404- parallel_iterations = pkg_constants .TF_LOOP_PARALLEL_ITERATIONS
405- )
406- for hess in full_data_model .hessians
407- ]
408- fisher_a , fisher_b = hessian_diagonal
403+ # with tf.name_scope("hessian_diagonal"):
404+ # hessian_diagonal = [
405+ # tf.map_fn(
406+ # # elems=tf.transpose(hess, perm=[2, 0, 1]),
407+ # elems=hess,
408+ # fn=tf.diag_part,
409+ # parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
410+ # )
411+ # for hess in full_data_model.hessians
412+ # ]
413+ # fisher_a, fisher_b = hessian_diagonal
409414
410415 mu = full_data_model .mu
411416 r = full_data_model .r
@@ -552,12 +557,8 @@ def __init__(
552557
553558 # we are minimizing the negative LL instead of maximizing the LL
554559 # => flip the sign of the Hessians; fisher_inv is the matrix inverse of the un-negated Hessians
555- self .hessian_diagonal = - tf .concat ([
556- fisher_a ,
557- fisher_b ,
558- ], axis = - 1 )
559- self .fisher_loc = tf .transpose (fisher_a , name = "fisher_loc" )
560- self .fisher_scale = tf .transpose (fisher_b , name = "fisher_scale" )
560+ self .hessians = - full_data_model .hessians
561+ self .fisher_inv = tf .matrix_inverse (full_data_model .hessians )
561562
562563 with tf .name_scope ('summaries' ):
563564 tf .summary .histogram ('a' , model_vars .a )
@@ -715,7 +716,7 @@ def __init__(self,
715716 X = input_data .X .assign_coords (group = (("observations" ,), inverse_idx ))
716717 mean = X .groupby ("group" ).mean (dim = "observations" )
717718
718- [X [inverse_idx == i ].mean (dim = "observations" ).values for i in np .unique (inv_design )]
719+ [X [inverse_idx == i ].mean (dim = "observations" ).values for i in np .unique (inv_design )]
719720 a = np .log (mean )
720721 # a = a * np.eye(np.size(a))
721722 a_prime = np .matmul (inv_design , a )
@@ -978,16 +979,12 @@ def gradient(self):
978979 return self .to_xarray ("full_gradient" , coords = self .input_data .data .coords )
979980
980981 @property
981- def hessian_diagonal (self ):
982- return self .to_xarray ("hessian_diagonal" , coords = self .input_data .data .coords )
983-
984- @property
985- def fisher_loc (self ):
986- return self .to_xarray ("fisher_loc" , coords = self .input_data .data .coords )
982+ def hessians (self ):
983+ return self .to_xarray ("hessians" , coords = self .input_data .data .coords )
987984
988985 @property
989- def fisher_scale (self ):
990- return self .to_xarray ("fisher_scale " , coords = self .input_data .data .coords )
986+ def fisher_inv (self ):
987+ return self .to_xarray ("fisher_inv " , coords = self .input_data .data .coords )
991988
992989 def finalize (self ):
993990 store = XArrayEstimatorStore (self )
0 commit comments