From 5eb3f58f2bcfca73fceb650aaaa0d98f8336a083 Mon Sep 17 00:00:00 2001
From: Michael Condon
Date: Thu, 4 Dec 2025 14:38:39 +0000
Subject: [PATCH 1/2] Added spectral learning for categorical HMMs.

---
 .../models/categorical_hmm.py | 188 +++++++++++++++++-
 dynamax/ssm.py                |  29 +++
 dynamax/utils/utils.py        |  77 ++++++-
 3 files changed, 291 insertions(+), 3 deletions(-)

diff --git a/dynamax/hidden_markov_model/models/categorical_hmm.py b/dynamax/hidden_markov_model/models/categorical_hmm.py
index 7d3b2ff2..2f689101 100644
--- a/dynamax/hidden_markov_model/models/categorical_hmm.py
+++ b/dynamax/hidden_markov_model/models/categorical_hmm.py
@@ -1,7 +1,8 @@
 """Categorical Hidden Markov Model."""
-from typing import NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Optional, Tuple, Union, List

 import jax.numpy as jnp
+from jax import lax
 import jax.random as jr
 import tensorflow_probability.substrates.jax.bijectors as tfb
 import tensorflow_probability.substrates.jax.distributions as tfd
@@ -15,7 +16,8 @@
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions
 from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
 from dynamax.types import IntScalar, Scalar
-from dynamax.utils.utils import pytree_sum
+from dynamax.types import PRNGKeyT
+from dynamax.utils.utils import pytree_sum, ensure_array_has_batch_dim, low_rank_pinv, multilinear_product, cp_decomp


 class ParamsCategoricalHMMEmissions(NamedTuple):
@@ -118,6 +120,43 @@ def m_step(self, params, props, batch_stats, m_step_state):
         probs = tfd.Dirichlet(self.emission_prior_concentration + emission_stats['sum_x']).mode()
         params = params._replace(probs=probs)
         return params, m_step_state
+
+    def calc_sample_moment(self,
+                           emissions: Float[Array, "num_batches num_timesteps emission_dim"],
+                           order: Union[int, List[int]]):
+        r"""Find the sample cross moments of order $n$.
+
+        These are averaged over the full time series because the HMM is
+        time-homogeneous, so, for example, the following are assumed
+        interchangeable:
+
+        $$\mathbb{E}[x_1 \otimes x_2 \otimes x_3]$$
+
+        $$\mathbb{E}[x_{t+1} \otimes x_{t+2} \otimes x_{t+3}]$$
+        """
+        x = one_hot(jnp.squeeze(emissions, -1), num_classes=self.num_classes)
+        B, T, _ = x.shape
+        if isinstance(order, int):
+            order = list(range(order))
+        order_len = max(order) + 1
+        T_effective = T - order_len + 1
+        if T_effective <= 0:
+            raise ValueError(f"Sequences of length {T} are too short for a moment of span {order_len}.")
+
+        # Build an einsum that sums the outer products of the time-lagged
+        # slices over the batch (axis 0) and time (axis 1) axes.
+        einsum_args = []
+        output_indices = []
+        for j in order:
+            slice_j = x[:, j:T_effective + j, :]
+            einsum_args.append(slice_j)
+            einsum_args.append([0, 1, 2 + j])
+            output_indices.append(2 + j)
+
+        einsum_args.append(output_indices)
+        sum_outer_products = jnp.einsum(*einsum_args)
+
+        # Average over the B * T_effective windows that were summed above.
+        return sum_outer_products / (B * T_effective)
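# A runnable sketch of what `calc_sample_moment` computes for order=[0, 1],
# assuming a toy sequence with num_classes = 3; the data below is an
# illustrative assumption, not part of the patch.
import jax.numpy as jnp
from jax.nn import one_hot

emissions = jnp.array([[[0], [2], [1], [0]]])        # (B=1, T=4, emission_dim=1)
x = one_hot(jnp.squeeze(emissions, -1), 3)           # one-hot encode: (B, T, 3)
B, T, _ = x.shape
T_eff = T - 1                                        # number of (x_t, x_{t+1}) windows
# average the outer products x_t (x) x_{t+1} over batch and time
M2 = jnp.einsum('bti,btj->ij', x[:, :T_eff], x[:, 1:]) / (B * T_eff)
print(M2.sum())                                      # 1.0: each window contributes one unit of mass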
+
+    def calc_pos_sample_mean(self, emissions, pos):
+        """Return the sample mean of the one-hot emissions at timestep `pos`, averaged across sequences."""
+        return jnp.mean(one_hot(jnp.squeeze(emissions, -1), num_classes=self.num_classes)[:, pos, :], axis=0)


 class CategoricalHMM(HMM):
@@ -186,3 +225,148 @@ def initialize(self,
         params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
         params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_probs=emission_probs)
         return ParamsCategoricalHMM(**params), ParamsCategoricalHMM(**props)
+
+
+    def get_view(self,
+                 batch_emissions: Float[Array, "num_batches num_timesteps emission_dim"],
+                 target: int,
+                 num_init: int = 100,
+                 num_iter: int = 1000,
+                 key: Array = jr.PRNGKey(0)
+                 ) -> Array:
+        r"""Return the sample conditional means from the requested view.
+
+        Specifically, return the conditional mean of the emissions at time
+        t+target given the hidden state at time t+1:
+
+        $$\mathbb{E}[x_{t+target} \mid y_{t+1}=h]$$
+
+        Args:
+            batch_emissions: the emission data.
+            target: the requested timestep relative to the hidden state being conditioned on.
+            num_init: number of random starting points used in the robust tensor power method.
+            num_iter: number of iterations in the robust tensor power method.
+            key: random number generator for the random restarts of the robust tensor power method. Defaults to jr.PRNGKey(0).
+
+        Returns:
+            Conditional mean vectors, one row per hidden state.
+        """
+        k = self.num_states
+        sym_M2, sym_M3 = self.moment_view(batch_emissions, target, k)
+
+        # Find a whitening matrix W such that W.T @ sym_M2 @ W = I_k,
+        # using the top-k eigenpairs of the symmetrized second moment.
+        eigvals, eigvecs = jnp.linalg.eigh(sym_M2)
+        idx = jnp.argsort(jnp.abs(eigvals))[-k:]
+        trunc_eigvals = eigvals[idx]
+        trunc_eigvecs = eigvecs[:, idx]
+
+        U = trunc_eigvecs
+        D = jnp.diag(1 / jnp.sqrt(trunc_eigvals))
+
+        W = U @ D
+        B = jnp.linalg.pinv(W.T)
+
+        # Decompose the whitened third moment into its rank-one components.
+        tilde_sym_M3 = multilinear_product(sym_M3, [W, W, W])
+        rob_eigvecs, rob_eigvals = cp_decomp(tilde_sym_M3, L=num_init, N=num_iter, k=k, key=key)
+
+        # Un-whiten: rescale the recovered eigenvectors back to conditional means.
+        return jnp.diag(rob_eigvals) @ rob_eigvecs @ B.T
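# A self-contained sketch of the whitening step inside `get_view`: the top-k
# eigenpairs of a PSD matrix give W with W.T @ M2 @ W = I_k. The random matrix
# here is an illustrative stand-in for a symmetrized second moment.
import jax.numpy as jnp
import jax.random as jr

k = 2
S = jr.normal(jr.PRNGKey(0), (5, 5))
M2 = S @ S.T                                         # PSD, full rank
eigvals, eigvecs = jnp.linalg.eigh(M2)               # ascending eigenvalues
idx = jnp.argsort(jnp.abs(eigvals))[-k:]             # keep the top-k
W = eigvecs[:, idx] @ jnp.diag(1 / jnp.sqrt(eigvals[idx]))
print(jnp.allclose(W.T @ M2 @ W, jnp.eye(k), atol=1e-4))  # True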
+
+
+    def fit_moments(
+        self,
+        params: ParameterSet,
+        props: PropertySet,
+        emissions: Union[Float[Array, "num_timesteps emission_dim"],
+                         Float[Array, "num_batches num_timesteps emission_dim"]],
+        num_init: int = 100,
+        num_iter: int = 1000,
+        key: Array = jr.PRNGKey(0)
+    ) -> ParameterSet:
+        r"""Estimate the parameters using the method of moments.
+
+        Specifically, recover the emission distribution and the transition
+        matrix from the second- and third-order cross moments. Since the model
+        is time-homogeneous, these moments can be averaged over all
+        consecutive pairs and triples of timesteps. To recover the initial
+        distribution, take the mean of the emissions over the first timestep
+        of each sequence and map it back through the pseudoinverse of the
+        estimated emission distribution.
+
+        Args:
+            params: model parameters $\theta$
+            props: properties specifying which parameters should be learned
+            emissions: observations from data.
+            num_init: number of random starting points used in the robust tensor power method.
+            num_iter: number of iterations in the robust tensor power method.
+            key: random number generator key for the robust tensor power method.
+
+        Returns:
+            new parameters
+
+        """
+        batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape)
+        key_2, key_3 = jr.split(key, 2)
+        mu_1 = self.get_view(batch_emissions, 1, num_init, num_iter, key_2)
+        mu_2 = self.get_view(batch_emissions, 2, num_init, num_iter, key_3)
+        k = self.num_states
+
+        transition_params = mu_2 @ low_rank_pinv(mu_1, k)
+        emission_params = mu_1
+
+        initial_params = low_rank_pinv(emission_params.T, k) @ self.emission_component.calc_pos_sample_mean(batch_emissions, 0)
+        params = params._replace(initial=initial_params, transitions=transition_params, emissions=emission_params)
+
+        return params
+
+
+    def moment_view(self,
+                    batch_emissions: Float[Array, "num_batches num_timesteps emission_dim"],
+                    target: int,
+                    k: int
+                    ) -> Tuple[Array, Array]:
+        r"""Perform the symmetrizing operation to get a particular view of the
+        second and third order moments.
+
+        Specifically, compute the second and third sample cross moments,
+        averaged over all consecutive pairs and triples of timesteps (the
+        model is time-homogeneous). Then symmetrize them so that both source
+        views are mapped onto the conditional means of the `target` view,
+        leaving tensors with a shared set of rank-one factors.
+
+        Args:
+            batch_emissions: the emission data.
+            target: the requested view.
+            k: the number of hidden states.
+
+        Returns:
+            sym_M2: symmetrized second order moment corresponding to view of `target`.
+            sym_M3: symmetrized third order moment corresponding to view of `target`.
+        """
+        if target == 0:
+            source_1, source_2 = 1, 2
+        elif target == 1:
+            source_1, source_2 = 0, 2
+        else:
+            source_1, source_2 = 0, 1
+
+        A = self.emission_component.calc_sample_moment(batch_emissions, [target, source_2])
+        B = self.emission_component.calc_sample_moment(batch_emissions, [source_1, source_2])
+        C = self.emission_component.calc_sample_moment(batch_emissions, [target, source_1])
+        D = self.emission_component.calc_sample_moment(batch_emissions, [source_2, source_1])
+
+        M2 = B  # the pairwise moment of the two source views
+        M3 = self.emission_component.calc_sample_moment(batch_emissions, 3)
+
+        d = self.emission_component.num_classes
+
+        sym_pre = jnp.transpose(A @ low_rank_pinv(B, k))
+        sym_post = jnp.transpose(C @ low_rank_pinv(D, k))
+
+        M3_args = [jnp.eye(d)] * 3
+        M3_args[source_1] = sym_pre
+        M3_args[source_2] = sym_post
+
+        sym_M2 = multilinear_product(M2, [sym_pre, sym_post])
+        sym_M3 = multilinear_product(M3, M3_args)
+        return sym_M2, sym_M3
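# A sketch of why `fit_moments` uses `low_rank_pinv` (added to
# dynamax/utils/utils.py below) rather than jnp.linalg.pinv: sampling noise
# makes the moment matrices full rank, so the pseudoinverse is truncated to
# the known rank k. The sizes and noise scale are illustrative assumptions.
import jax.numpy as jnp
import jax.random as jr
from dynamax.utils.utils import low_rank_pinv   # added by this patch

key_a, key_n = jr.split(jr.PRNGKey(1))
A = jr.normal(key_a, (6, 2))
X = A @ A.T                                          # rank-2 "population" moment
X_noisy = X + 1e-3 * jr.normal(key_n, (6, 6))        # "sample" moment, full rank
X_pinv = low_rank_pinv(X_noisy, k=2)                 # keep only the top-2 directions
print(jnp.allclose(X @ X_pinv @ X, X, atol=0.05))    # approximately Moore-Penrose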
diff --git a/dynamax/ssm.py b/dynamax/ssm.py
index fb844448..6bcfde61 100644
--- a/dynamax/ssm.py
+++ b/dynamax/ssm.py
@@ -478,3 +478,32 @@ def _loss_fn(unc_params, minibatch):
             params = from_unconstrained(unc_params, props)

         return params, losses
+
+    def fit_moments(
+        self,
+        params: ParameterSet,
+        props: PropertySet,
+        emissions: Union[Float[Array, "num_timesteps emission_dim"],
+                         Float[Array, "num_batches num_timesteps emission_dim"]],
+        key: Array = jr.PRNGKey(0)
+    ) -> ParameterSet:
+        r"""Estimate the parameters using the method of moments.
+
+        Specifically, compute the second- and third-order cross moments.
+        Since the model is time-homogeneous, these can be averaged over all
+        consecutive pairs and triples of timesteps:
+
+        $$M_2 = \mathbb{E}[x_1 \otimes x_2]$$
+        $$M_3 = \mathbb{E}[x_1 \otimes x_2 \otimes x_3]$$
+
+        Then recover the parameters from a decomposition of these moments.
+
+        Args:
+            params: model parameters $\theta$
+            props: properties specifying which parameters should be learned
+            emissions: observations from data.
+            key: random number generator key used by the decomposition subroutines.
+
+        Returns:
+            new parameters
+
+        """
+        raise NotImplementedError
\ No newline at end of file
diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py
index 3dd439c7..603b81cb 100644
--- a/dynamax/utils/utils.py
+++ b/dynamax/utils/utils.py
@@ -8,7 +8,7 @@
 from functools import partial

 from jax import jit
-from jax import vmap
+from jax import vmap, lax
 from jax.tree_util import tree_map, tree_leaves, tree_flatten, tree_unflatten
 from jaxtyping import Array, Int
 from scipy.optimize import linear_sum_assignment
@@ -220,3 +220,78 @@ def psd_solve(A, b, diagonal_boost=1e-9):
 def symmetrize(A):
     """Symmetrize one or more matrices."""
     return 0.5 * (A + jnp.swapaxes(A, -1, -2))
+
+def multilinear_product(core, factors):
+    """Multilinear map of a core tensor of order p with p matrices.
+
+    Each factor's first axis is contracted against the corresponding axis of
+    the core. For an order-3 core tensor of shape (I, J, K), the factor
+    matrices therefore have shape (I, P), (J, Q) and (K, R) respectively,
+    and the output has shape (P, Q, R).
+    """
+    order = core.ndim
+    assert order == len(factors)
+    einsum_args = [core, list(range(order))]
+    for i, factor in enumerate(factors):
+        einsum_args.append(factor)
+        einsum_args.append([i, i + order])
+    einsum_args.append(list(range(order, 2 * order)))
+
+    return jnp.einsum(*einsum_args)
+
+def low_rank_pinv(X, k):
+    """Find the Moore-Penrose pseudoinverse of a matrix X with rank k.
+
+    This is more robust than jnp.linalg.pinv since the sample cross moments
+    will likely have higher rank than the population cross moments.
+
+    Here, we take the SVD, which sorts the singular values in descending
+    order, and truncate to the first k.
+    """
+    u, s, vt = jnp.linalg.svd(X)
+    u_trunc = u[:, :k]
+    s_trunc = s[:k]
+    vt_trunc = vt[:k, :]
+    return vt_trunc.T @ jnp.diag(1.0 / s_trunc) @ u_trunc.T
+
+def rtpm_eigvals(X, y):
+    r"""Evaluate $X(y, y, y)$, the eigenvalue of X at a unit eigenvector $y$."""
+    return multilinear_product(X, [y, y, y])
+
+def rtpm(X, key=jr.PRNGKey(0), L=100, N=1000):
+    """Apply one round of the robust tensor power method to a tensor X,
+    returning the deflated tensor along with the extracted robust
+    eigenvector/eigenvalue pair.
+    """
+    assert X.ndim == 3
+    assert len(set(X.shape)) == 1
+    keys = jr.split(key, L)
+    k = X.shape[0]
+
+    def power_iter_update(theta, _):
+        # theta <- X(I, theta, theta) / ||X(I, theta, theta)||
+        mlm = multilinear_product(X, [jnp.eye(k), theta, theta]).squeeze(-1)
+        return jnp.divide(mlm, jnp.linalg.norm(mlm)), None
+
+    def theta_sample(theta_key):
+        # random point on the unit sphere in R^k
+        Z = jr.normal(theta_key, shape=(k, 1))
+        norm_Z = jnp.linalg.norm(Z)
+        theta_init = jnp.divide(Z, norm_Z)
+
+        theta_N, _ = lax.scan(power_iter_update,
+                              theta_init,
+                              length=N)
+        return theta_N
+
+    # Run N power iterations from each of the L random restarts, keep the
+    # candidate with the largest eigenvalue, then polish it with N more
+    # iterations before deflating.
+    theta_arr = vmap(theta_sample)(keys)
+    tau_star = jnp.argmax(vmap(partial(rtpm_eigvals, X))(theta_arr))
+    theta_hat, _ = lax.scan(power_iter_update,
+                            theta_arr[tau_star],
+                            length=N)
+    lambda_hat = rtpm_eigvals(X, theta_hat).squeeze()
+    theta_hat = theta_hat.squeeze()
+    def_X = X - lambda_hat * jnp.einsum('a,b,c->abc', theta_hat, theta_hat, theta_hat)
+    return def_X, (theta_hat, lambda_hat)
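# A sketch of `rtpm` on a synthetic orthogonally decomposable tensor
# T3 = 2 e1(x)e1(x)e1 + 1 e2(x)e2(x)e2: the power method with random restarts
# should recover the dominant pair (e1, 2.0) and deflate it away. The tensor,
# L, and N below are illustrative assumptions.
import jax.numpy as jnp
import jax.random as jr
from dynamax.utils.utils import rtpm   # added by this patch

e1, e2 = jnp.eye(2)
T3 = (2.0 * jnp.einsum('a,b,c->abc', e1, e1, e1)
      + 1.0 * jnp.einsum('a,b,c->abc', e2, e2, e2))
def_T3, (vec, val) = rtpm(T3, key=jr.PRNGKey(0), L=10, N=50)
print(vec, val)   # vec close to e1, val close to 2.0
# def_T3 is now close to 1 * e2(x)e2(x)e2, ready for the next extraction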
+
+def cp_decomp(X, L, N, k, key):
+    """Apply the robust tensor power method iteratively, deflating after each
+    extraction, and return the k robust eigenvectors and eigenvalues.
+    """
+    _, (eigvecs, eigvals) = lax.scan(partial(rtpm, L=L, N=N), X, jr.split(key, k))
+    return eigvecs, eigvals
\ No newline at end of file

From 12c5faceec681db8b9d8edd76ce02f342875c0b3 Mon Sep 17 00:00:00 2001
From: Michael Condon
Date: Thu, 4 Dec 2025 21:59:43 +0000
Subject: [PATCH 2/2] minor update

Signed-off-by: Michael Condon
---
 dynamax/utils/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py
index 603b81cb..6fd23767 100644
--- a/dynamax/utils/utils.py
+++ b/dynamax/utils/utils.py
@@ -245,6 +245,7 @@ def low_rank_pinv(X, k):
     Here, we take the SVD, which sorts the singular values in descending
     order, and truncate to the first k.
     """
+
     u, s, vt = jnp.linalg.svd(X)
     u_trunc = u[:, :k]
     s_trunc = s[:k]
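# An end-to-end sketch of the API added by this patch. The model sizes, keys,
# and sequence length are illustrative assumptions, not part of the patch.
import jax.random as jr
from dynamax.hidden_markov_model import CategoricalHMM

hmm = CategoricalHMM(num_states=3, emission_dim=1, num_classes=5)
params, props = hmm.initialize(jr.PRNGKey(0))
_, emissions = hmm.sample(params, jr.PRNGKey(1), num_timesteps=500)   # (500, 1)
# Spectral (method-of-moments) estimate; it could also seed EM as a warm start.
params_mom = hmm.fit_moments(params, props, emissions,
                             num_init=100, num_iter=1000, key=jr.PRNGKey(2))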