Hi, thank you for the great package.
I am working with the transition module of `mp_srslds` and came across the following code:

    Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
    for k1 in range(self.K):
        for k2 in range(self.K):
            vtilde = vtildes[:,k1,k2][:,None] # SWAP?
            #Sticky terms
            if k1==k2:
                Rv = vtilde@self.Ss[k2:k2+1,:]
                hess += Ez[k1,k2] * \
                        ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Ss[k2:k2+1,:], self.Ss[k2:k2+1,:]) \
                        + np.einsum('ti, tj -> tij', Rv, Rv))
            #Switching terms
            else:
                Rv = vtilde@self.Rs[k2:k2+1,:]
                hess += Ez[k1,k2] * \
                        ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs[k2:k2+1,:], self.Rs[k2:k2+1,:]) \
                        + np.einsum('ti, tj -> tij', Rv, Rv))

Here, on line 89, `Ez` is indexed by `k1` and `k2`. However, on line 82 it is defined as `Ez = np.sum(expected_joints, axis=2)`, with the comment "marginal over z from T=1 to T-1". I then checked how `expected_joints` is computed:

    # Compute E[z_t, z_{t+1}] for t = 1, ..., T-1
    # Note that this is an array of size T*K*K, which can be quite large.
    # To be a bit more frugal with memory, first check if the given log_Ps
    # are TxKxK. If so, instantiate the full expected joints as well, since
    # we will need them for the M-step. However, if log_Ps is 1xKxK then we
    # know that the transition matrix is stationary, and all we need for the
    # M-step is the sum of the expected joints.
    stationary = (Ps.shape[0] == 1)
    if not stationary:
        expected_joints = alphas[:-1,:,None] + betas[1:,None,:] + ll[1:,None,:] + log_Ps
        expected_joints -= expected_joints.max((1,2))[:,None, None]
        expected_joints = np.exp(expected_joints)
        expected_joints /= expected_joints.sum((1,2))[:,None,None]

I believe `expected_joints` should have dimensions `(T-1, K, K)`. As a result, `Ez` would have dimensions `(T-1, K)`, but as shown above its first (time) dimension is actually indexed using `k1`, which is a bit confusing to me.
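To make the shape question concrete, here is a minimal NumPy sketch that repeats the broadcasting from the second snippet with random stand-in inputs. The function name `expected_joints_shapes` and the values of `T` and `K` are made up for illustration; this is not code from the package, and it only covers the non-stationary branch.

```python
import numpy as np

def expected_joints_shapes(T=5, K=3, seed=0):
    """Rebuild expected_joints from random inputs and report the resulting shapes."""
    rng = np.random.default_rng(seed)
    # Hypothetical stand-ins for the forward/backward messages and log-likelihoods.
    alphas = rng.normal(size=(T, K))
    betas = rng.normal(size=(T, K))
    ll = rng.normal(size=(T, K))
    log_Ps = rng.normal(size=(T - 1, K, K))  # non-stationary: one matrix per step

    # Same broadcasting and normalization as the quoted lines.
    expected_joints = alphas[:-1, :, None] + betas[1:, None, :] + ll[1:, None, :] + log_Ps
    expected_joints -= expected_joints.max((1, 2))[:, None, None]
    expected_joints = np.exp(expected_joints)
    expected_joints /= expected_joints.sum((1, 2))[:, None, None]

    Ez_axis2 = np.sum(expected_joints, axis=2)  # the sum used on line 82
    Ez_axis0 = np.sum(expected_joints, axis=0)  # sum over time, for comparison
    return expected_joints.shape, Ez_axis2.shape, Ez_axis0.shape

shapes = expected_joints_shapes()
print(shapes)  # ((4, 3, 3), (4, 3), (3, 3))
```

With `T=5` and `K=3`, the `axis=2` sum has shape `(4, 3)`, so `Ez[k1, k2]` reads row `k1` as a time step. Summing over `axis=0` instead gives a `(K, K)` array, which would match the indexing, though I am not sure that is the intended quantity.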
Could you clarify if this behavior is intentional, or if there might be a mistake in how Ez is used? I may be missing something here, so I’d appreciate your insight. Thanks for your time and support!
For reference, the first snippet is from `ssm/ssm/extensions/mp_srslds/transitions_ext.py` (lines 82 to 97 at commit 6c856ad) and the second is from `ssm/ssm/messages.py` (lines 186 to 198 at commit 6c856ad).