Conversation
Critical fixes: - Fix negative sampling gradient in SparseCategoricalSoftmaxOutputLayer Previously added +1.0 to negative samples, now correctly computes softmax gradients (true: prob-1.0, negative: prob, other: 0) - Add unigram^0.75 distribution for negative sampling in Vocab - Add word2vec-style embedding initialization (uniform [-0.5/dim, 0.5/dim]) - Add exists_ok parameter to layer registry for deserialization support Evaluation improvements: - Add comprehensive word analogy evaluation (semantic and syntactic) - Implement nearest neighbor search with cosine similarity - Add CLI eval command for model evaluation - Support word arithmetic calculations
RNN layer features: - Vanilla RNN with BPTT (Backpropagation Through Time) - Support for both many-to-one and many-to-many architectures - Gradient clipping to prevent exploding gradients - Multiple parameter initialization methods (Xavier, He, random) - Proper hidden state management across time steps Additional components: - RNN module wrapper following Dense module pattern - Comprehensive test suite with 27 tests covering: * Forward/backward propagation * Gradient computation and clipping * Serialization/deserialization * Parameter initialization methods * Various sequence configurations - RNN language model demo with CLI interface
Issues added for each failing test skipped.
📝 WalkthroughWalkthroughAdds Huffman-tree and hierarchical softmax support, a full Recurrent RNN layer and RNN module, configurable word2vec softmax strategies, extensive sample CLIs and tests, layer exists_ok propagation, optimiser handling for output-layer gradients, Nix/Just workflow updates, and ETag-based resource caching. Changes
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related issues
Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Tip Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 16
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
🤖 Fix all issues with AI agents
In `@mo_net/data_structures/huffman_tree.py`:
- Around line 254-282: The deserialize method in HuffmanTree restores vocab_size
but forces num_internal_nodes = vocab_size - 1 which mismatches the single-word
build() behavior (where build() sets num_internal_nodes = 1) and can break
hierarchical-softmax vector allocation; update the implementation so
num_internal_nodes is preserved or consistently computed: either (A) modify
serialize() to include "num_internal_nodes" and have deserialize() read and
assign it back to tree.num_internal_nodes, or (B) make both build(),
serialize(), and deserialize() compute num_internal_nodes using a single formula
such as max(vocab_size - 1, 1) so the value is consistent across
creation/serialization/deserialization (change code in HuffmanTree.serialize,
HuffmanTree.deserialize, and HuffmanTree.build accordingly).
- Around line 99-112: The single-word branch in build bypasses __init__ (using
cls.__new__) and sets num_internal_nodes=1, causing
serialization/deserialization mismatch; instead construct the tree via the
normal initializer and centralize single-word handling in __init__ so all
attributes (_paths, _codes, root, vocab_size, num_internal_nodes) are set
consistently, and update deserialize to compute num_internal_nodes as max(0,
vocab_size - 1) (or otherwise use the same formula as __init__) so the
single-word case yields num_internal_nodes=0 after deserialization; reference
HuffmanTree.__init__, HuffmanTree.build, HuffmanTree.deserialize, HuffmanNode,
and methods _build_paths/_paths/_codes when making the change.
In `@mo_net/model/layer/output.py`:
- Around line 487-542: The _backward_prop method currently uses Python for-loops
over batch and path nodes (in _backward_prop and calls to self.tree.get_path),
which prevents JAX tracing/JIT; also embedding_dim is unused — fix by making the
path data JAX-compatible and vectorizing the gradient computation: change
tree.get_path to return batched JAX arrays (node_indices: int32[batch,
max_path], directions: int8[batch, max_path], and path_lengths) or add a helper
that converts per-sample Python paths into padded arrays, then compute per-node
scores and sigmoids with jnp.take(self._parameters.node_vectors, node_indices,
axis=0) and jnp.einsum or jnp.sum to get scores for all batch×path entries,
compute grad = sigmoid - directions (cast) as a jnp array, then accumulate
grad_input via jnp.einsum('bpv,bp->bv', node_vectors_selected, grad) and
grad_node_vectors via jax.ops.segment_sum or jnp.zeros_like + jax.ops.index_add
(or use jax.vmap over samples with lax.scan over path dimension) to produce
grad_node_vectors; store dP with
self._cache["dP"]=d(self.Parameters(node_vectors=grad_node_vectors)), and remove
the unused embedding_dim variable.
In `@mo_net/model/layer/recurrent.py`:
- Around line 67-71: The __rsub__ implementation is backwards: when Python
computes other - self it should return other - self, not self - other; update
__rsub__ in the class containing __sub__/__add__ so it negates self then adds
other (e.g. return (-self).__add__(other) or equivalently return (-self) +
other) ensuring the operation uses the same __add__/negation semantics as
__sub__/__add__ for Parameters.
- Around line 241-255: Update Recurrent.__init__ to accept an exists_ok: bool =
False parameter and forward it to super().__init__(...) so the LayerRegistry
uniqueness check can be bypassed when desired; then modify
Recurrent.Serialized.deserialize to pass exists_ok=True when constructing
Recurrent (i.e., Recurrent(..., exists_ok=True, parameters_init_fn=...,
freeze_parameters=...)) so deserialization can safely re-create layers that
already exist. Ensure parameter name matches exactly (exists_ok) and is threaded
through to super().__init__.
In `@mo_net/optimiser/adam.py`:
- Around line 87-90: The snapshot/restore logic currently only iterates
self._model.grad_layers and omits the output layer state when output_layer is a
GradLayer; update snapshot() and restore() to also handle
self._model.output_module.output_layer when isinstance(output_layer, GradLayer):
initialize its optimizer caches (first_moment, second_moment) the same way
gradient_operation() does and include those keys in the snapshot dict and in
restore() so the output layer's optimizer state is persisted and recovered just
like other GradLayer instances.
In `@mo_net/optimiser/rmsprop.py`:
- Around line 73-76: The output layer (self._model.output_module.output_layer)
is being processed by gradient_operation but its RMSProp state
"squared_grad_avg" is never initialized, causing KeyError; update __init__ in
rmsprop.py to also create cache["squared_grad_avg"] for the output layer when
isinstance(output_layer, GradLayer), and mirror that change in snapshot() and
restore() so the output layer's squared_grad_avg is included in the saved state
and restored correctly; reference the symbols __init__, snapshot, restore,
gradient_operation, squared_grad_avg, GradLayer, and
self._model.output_module.output_layer when making these additions.
In `@mo_net/resources.py`:
- Around line 55-56: The current HEAD probe using requests.head(download_url)
followed by head_response.raise_for_status() can abort downloads when servers
deny HEAD; modify the logic around requests.head and
head_response.raise_for_status() (symbols: requests.head, head_response,
raise_for_status, download_url) to catch exceptions
(requests.RequestException/HTTPError) from the HEAD call and treat them as
non-fatal — log or ignore the HEAD failure and fall through to the existing
GET/download path instead of returning/raising; only abort if the subsequent GET
fails.
- Around line 54-56: The requests.head and requests.get calls (head_response =
requests.head(download_url) and the subsequent requests.get(download_url, ...))
lack a timeout and can hang; update both to include a timeout argument (e.g.,
timeout=REQUEST_TIMEOUT) and introduce a single configurable constant or
function parameter (REQUEST_TIMEOUT or timeout) used for both calls so the value
is consistent and easily adjustable.
In `@mo_net/samples/rnn_language/__main__.py`:
- Around line 250-254: The Y arrays were flattened which collapses the sequence
dimension and breaks batching for BasicTrainer; instead keep Y_train and Y_val
with the same first-dimension as X (do not call .flatten()) or reshape them to
(train_size, sequence_length) so their row count matches X_train[:train_size];
update the two instantiations (the BasicTrainer call around the listed lines and
the train_pretrained call referenced) to pass Y_train[:train_size] and
Y_train[train_size:] (or appropriately reshaped/squeezed forms) so batch
indexing aligns with X_train.
- Around line 88-118: The current output flattens RNN outputs and applies
Linear+SparseCategoricalSoftmax over seq_len * vocab_size, treating all
time-steps jointly; fix by making the classifier per-time-step instead: remove
the Flatten layer and the Linear that expects input_dimensions=(sequence_length
* hidden_dim,) and instead add a Linear that maps from hidden_dim -> vocab_size
applied to each time-step (i.e., keep RNN with return_sequences=True and use a
per-step Linear with input_dimensions=(hidden_dim,) and
output_dimensions=(vocab_size,)), and change SparseCategoricalSoftmaxOutputLayer
to expect input_dimensions=(vocab_size,) so each time-step gets its own softmax;
also update any downstream generate/reshape logic (the code that reshapes model
outputs in generate) to no longer reapply a global softmax but to use the
per-step softmax probabilities.
- Around line 566-579: The sampling code double-applies softmax and applies
temperature after softmax; fix by using pre-softmax logits and applying
temperature before softmax (or, if the model currently returns probabilities
from model.forward_prop because of SparseCategoricalSoftmaxOutputLayer, stop
re-softmaxing and sample directly from those probabilities). Specifically,
change model.forward_prop to return raw logits (or add a flag/alternative method
that returns logits) so you can compute last_logits =
logits.reshape(sequence_length, vocab_size)[-1], then if temperature != 1.0 do
last_logits = last_logits / temperature, compute probs =
jax.nn.softmax(last_logits), and finally sample with jax.random.choice(subkey,
vocab_size, p=probs); alternatively, if you cannot change the model, remove the
jax.nn.softmax call and sample from the probabilities returned by
model.forward_prop without temperature scaling.
In `@mo_net/samples/word2vec/__main__.py`:
- Around line 1098-1106: Replace the dynamic __import__ usage inside the list
comprehension with the already-imported function evaluate_analogy: call
evaluate_analogy(example, model, vocab, top_k) directly when building results
for analogies; ensure the list comprehension uses the imported evaluate_analogy
symbol (not __import__) and preserves the conditional filtering for truthy
results, and remove or stop the subsequent duplicate call to evaluate_analogies
that re-evaluates examples so the examples are evaluated only once and results
is reused.
In `@mo_net/tests/data_structures/test_huffman_tree.py`:
- Around line 261-271: Update the test_serialize_single_word_tree to assert that
num_internal_nodes is preserved across serialization round-trip: after creating
tree = HuffmanTree.build(frequencies), serializing via HuffmanTree.serialize and
deserializing via HuffmanTree.deserialize into tree2, add an assertion that
tree2.num_internal_nodes == tree.num_internal_nodes in addition to the existing
checks (vocab_size and get_path). This will catch the deserialize bug where
num_internal_nodes is not reconstructed correctly.
In `@mo_net/tests/trainer/test_parallel.py`:
- Around line 99-100: The top-level pytest.skip is too broad and skips all
parameterized cases; narrow the skip to only the failing
GradientTransferTestCase for "batch_norm". Replace the bare
GradientTransferTestCase(name="batch_norm", ...) entry in the parameter list
with a pytest.param(...) wrapper that includes
marks=pytest.mark.skip(reason="TODO: batch_norm gradient transfer broken") and
remove the top-level `@pytest.mark.skip` decorator and its TODO comment from the
test function; locate the entries using GradientTransferTestCase, Model and
BatchNorm to modify the specific parameter and use pytest.param and
pytest.mark.skip to scope the skip.
🟡 Minor comments (11)
mo_net/resources.py-24-31 (1)
24-31:⚠️ Potential issue | 🟡 MinorPrefix-matching on content hash could produce false cache hits.
_find_cached_by_hashusespath.name.startswith(content_hash), which would match if one ETag/hash is a prefix of another file's hash. This is unlikely with full-length hashes, but thesha256fallback truncates to 32 chars (line 71), and an ETag could be shorter. Consider matching up to the-separator instead.Proposed fix
for path in RESOURCE_CACHE.iterdir(): - if path.name.startswith(content_hash): + if path.name.startswith(content_hash + "-"): return pathmo_net/tests/data_structures/test_huffman_tree.py-147-155 (1)
147-155:⚠️ Potential issue | 🟡 Minor
assert len(nodes) >= 0is always true — strengthen the assertion.A list's length is always ≥ 0, so this assertion never fails. If the intent is to verify that a path exists, use
> 0instead.Proposed fix
- # Path should exist - assert len(nodes) >= 0 + # Path should be non-empty (every word needs at least one internal node) + assert len(nodes) > 0mo_net/model/layer/recurrent.py-446-480 (1)
446-480:⚠️ Potential issue | 🟡 Minor
h_ton line 455 is assigned but never used (dead code).Static analysis correctly flags this (F841). The variable
h_t = hidden_states[:, t + 1, :]is computed but never referenced — the backward pass recomputes the pre-activation value directly on lines 460–464. Remove it to avoid confusion.Proposed fix
# Get current gradients and activations dh_t = ( dh_all[:, t, :] + dh_next ) # Gradient from output + gradient from next timestep - h_t = hidden_states[:, t + 1, :] # Current hidden state h_prev = hidden_states[:, t, :] # Previous hidden state x_t = input_activations[:, t, :] # Current inputmo_net/samples/sequence/simple_rnn.py-16-16 (1)
16-16:⚠️ Potential issue | 🟡 MinorUnused variable
key2.
key2is never used. Prefix with_to signal intent.✏️ Suggested fix
- key1, key2 = jax.random.split(key) + key1, _key2 = jax.random.split(key)mo_net/tests/samples/word2vec/test_negative_sampling.py-116-123 (1)
116-123:⚠️ Potential issue | 🟡 Minor
SoftmaxConfig.negative_sampling(k=5)is hardcoded whilenegative_samplesvaries.The
SoftmaxConfigalways specifiesk=5, but the model is created withnegative_samples=n_negwheren_negiterates over[1, 5, 10]. Ifnegative_sampleson the model is supposed to match the config'snegative_samples, this is inconsistent for iterations wheren_neg != 5. Consider passingk=n_negto keep them in sync:✏️ Suggested fix
for n_neg in [1, 5, 10]: model = SkipGramModel.create( vocab=vocab, embedding_dim=8, - softmax_config=SoftmaxConfig.negative_sampling(k=5), + softmax_config=SoftmaxConfig.negative_sampling(k=n_neg), key=jax.random.PRNGKey(42), negative_samples=n_neg, )mo_net/tests/samples/word2vec/test_negative_sampling.py-6-11 (1)
6-11:⚠️ Potential issue | 🟡 MinorUnused imports:
dataclassandSparseCategoricalSoftmaxOutputLayer.Neither
dataclass(line 6) norSparseCategoricalSoftmaxOutputLayer(line 11) are referenced anywhere in this file.✏️ Suggested fix
import jax import jax.numpy as jnp import pytest -from dataclasses import dataclass from mo_net.samples.word2vec.softmax_strategy import SoftmaxConfig from mo_net.samples.word2vec.vocab import Vocab from mo_net.samples.word2vec import SkipGramModel -from mo_net.model.layer.output import SparseCategoricalSoftmaxOutputLayermo_net/tests/samples/word2vec/test_models.py-173-175 (1)
173-175:⚠️ Potential issue | 🟡 MinorUnused
batch_sizevariable (Ruff F841).
batch_sizeis assigned on Line 173 but never referenced—XandYare hard-coded. Either use it to construct the arrays or remove it.Proposed fix
- batch_size = 4 X = jnp.array([[0, 1], [1, 2], [2, 3], [3, 4]]) Y = jnp.array([2, 3, 4, 5])mo_net/tests/model/layer/test_output.py-57-59 (1)
57-59:⚠️ Potential issue | 🟡 MinorRemove unused
num_negativesvariable.
num_negativesis assigned on Line 58 but never referenced. Same issue at Line 88.Proposed fix
def test_backward_prop_with_negative_sampling_2d(self): """Test backward_prop_with_negative with 2D negative samples array.""" layer = SparseCategoricalSoftmaxOutputLayer(input_dimensions=(10,)) batch_size = 2 - num_negatives = 3 X = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 10))And similarly at Line 88:
def test_backward_prop_with_negative_sampling_1d(self): """Test backward_prop_with_negative with 1D negative samples array.""" layer = SparseCategoricalSoftmaxOutputLayer(input_dimensions=(10,)) batch_size = 2 - num_negatives = 3 X = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 10))mo_net/tests/model/layer/test_output.py-320-331 (1)
320-331:⚠️ Potential issue | 🟡 MinorUnused
outputvariable; forward_prop call is still needed.
outputon Line 327 is assigned but unused (flagged by Ruff F841). The call is needed to populate the cache, so just discard the return:Proposed fix
- output = layer.forward_prop(input_activations=Activations(X)) + layer.forward_prop(input_activations=Activations(X))mo_net/model/layer/output.py-347-364 (1)
347-364:⚠️ Potential issue | 🟡 MinorDeserialized layer skips
__init__cache initialization — optimizer state (first_moment,second_moment) not restored.
Serialized.deserialize()constructs the layer via__init__with onlyhuffman_tree, then overwrites_parameters. However, it doesn't restorefirst_moment/second_momentfrom the serialized state (they aren't serialized at all). If a deserialized layer is used for continued training, the Adam optimizer state will restart from zero, potentially causing a learning rate spike.This may be acceptable if deserialized models are only used for inference, but the
trainingparameter indeserializeis accepted (though unused — as the linter notes), suggesting training use is intended.mo_net/samples/rnn_language/__main__.py-309-315 (1)
309-315:⚠️ Potential issue | 🟡 MinorFragile module resolution in
Word2VecUnpickler.The condition
"rnn_language" in module(line 313) will redirect any pickle class whose module path contains"rnn_language"tomo_net.samples.word2vec.__main__. If the RNN language model itself is ever pickled, this would incorrectly remap its classes to the word2vec module.Suggested tighter condition
def find_class(self, module: str, name: str): - if module == "__main__" or "rnn_language" in module: + if module == "__main__": module = "mo_net.samples.word2vec.__main__" return super().find_class(module, name)
🧹 Nitpick comments (22)
mo_net/tests/smoke/test_mnist_cnn.py (1)
32-33: Consider linking a tracking issue so this skip doesn't become permanent.Skipped tests without a tracked issue tend to linger. Consider filing a GitHub issue and referencing it in the skip reason (e.g.,
reason="CNN smoke test broken, see#XX"), so this doesn't fall through the cracks.justfile (1)
32-51: Hardcoded remote host and path — consider documenting prerequisites.The
test-remotetarget assumes a specific SSH host alias (herakles) and remote path (~/mo-net/). This is fine as a personal dev convenience target, but a brief comment noting the expected SSH config (or a note in a README) would help other contributors understand this isn't a general-purpose target.Also, unknown flags (anything other than
--cuda) are silently ignored. If more flags are added later, consider adding a catch-all warning.flake.nix (1)
133-150: Defensive override fornvidia-nvshmem-cu12— consider adding a comment.Per CLAUDE.md (line 39), the intent is to avoid pulling in
nvidia-nvshmem-cu12by not using the["with-cuda"]extra. If that strategy holds, this override at lines 134–138 would never activate. It's good as a safety net, but a brief comment noting it's defensive would help future readers understand why it exists alongside the avoidance strategy.The
nvidia-cusolver-cu12andnvidia-cusparse-cu12overrides are properly guarded withlib.optionals (pkgsCuda != null).mo_net/data_structures/huffman_tree.py (1)
19-41: HuffmanNode:__lt__only — consider adding a tiebreaker for stable heap ordering.When two nodes have equal frequencies,
heapqrelies solely on__lt__, which returnsFalsefor both directions. While CPython handles this gracefully, the resulting tree shape is non-deterministic across equal-frequency entries. If reproducibility matters (e.g., for deterministic tests or serialized models), consider adding a tiebreaker (e.g., a creation counter).mo_net/model/layer/recurrent.py (1)
271-278: No-op conditional: both branches of the ternary are identical.
output_dimensions = hidden_dimensions if return_sequences else hidden_dimensionsalways evaluates tohidden_dimensions. If this is intentional (the feature dimension is alwayshidden_dimensionsregardless ofreturn_sequences), remove the conditional and add a brief comment for clarity.Proposed simplification
- # Output dimensions depend on return_sequences - output_dimensions = hidden_dimensions if return_sequences else hidden_dimensions + # Output feature dimension is always hidden_dimensions + # (return_sequences only affects the time axis, not tracked here) + output_dimensions = hidden_dimensionsmo_net/cli.py (1)
280-281: Stale comment:batch_sizeis always provided here.The comment "Batch normalisation requires batch_size" is misleading —
batch_sizeis typed asint(not optional) in this function's signature, so it's always present regardless of normalisation type. Consider removing or rewording it.✏️ Suggested change
if normalisation_type == NormalisationType.BATCH: - # Batch normalisation requires batch_sizemo_net/model/model.py (1)
534-555: Gradient cache logic correctly extended forGradLayeroutput.The
get_gradient_caches/populate_cachespair now consistently handles the optionalGradLayeroutput (e.g.,HierarchicalSoftmaxOutputLayer). The length validation at Line 548 provides a good safety net.One nit from static analysis:
zip()on Line 551 lacks an explicitstrict=parameter. Since you already validate lengths on Line 548, addingstrict=Truewould be a belt-and-suspenders improvement.Proposed diff
- for layer, update in zip(self.grad_layers, updates[: len(self.grad_layers)]): + for layer, update in zip(self.grad_layers, updates[: len(self.grad_layers)], strict=True):mo_net/tests/model/layer/test_output.py (1)
249-264: Document expected behavior when true label overlaps with negatives.The comment on Line 263 says "later assignment overwrites," which means the negative-sample gradient (positive value) overwrites the true-label gradient (prob − 1). This is a lossy semantic—effectively the model no longer gets a learning signal for that true label in that sample. While the test documents current behavior, consider adding an assertion that validates the specific resulting value to make the test more precise and catch future regressions.
mo_net/tests/model/layer/test_hierarchical_output.py (1)
64-64: Tests call_forward_prop/_backward_prop(private methods) instead of public API.Throughout this file, tests invoke
layer._forward_prop(...)andlayer._backward_prop(...)directly (e.g., Lines 64, 88, 109, 113, 132, 152, 184, 230, 237, 256, 274, 280). If the base classOutputLayer.forward_propever adds pre/post-processing around_forward_prop, these tests won't exercise the full path. Consider usinglayer.forward_prop(...)/layer.backward_prop(...)for more resilient tests.mo_net/tests/samples/word2vec/test_softmax_strategy.py (1)
111-119: Use a more specific exception type instead of bareException.
pytest.raises(Exception)(Lines 115, 118) is overly broad and could mask unrelated failures (Ruff B017). Frozen dataclasses raiseFrozenInstanceError(a subclass ofAttributeError). UsingAttributeErroris cross-version safe:Proposed fix
- with pytest.raises(Exception): # FrozenInstanceError + with pytest.raises(AttributeError): # FrozenInstanceError config.strategy = SoftmaxStrategy.NEGATIVE_SAMPLING # type: ignore - with pytest.raises(Exception): # FrozenInstanceError + with pytest.raises(AttributeError): # FrozenInstanceError config.negative_samples = 5 # type: ignoremo_net/tests/samples/word2vec/test_models.py (1)
326-328: Tests reach into private_negative_sampling_distfor setup.Lines 328 and 351 set
model._negative_sampling_distdirectly. If there's a public setter or ifcreate(...)can be configured to set this automatically, that would make the tests less coupled to internal state. Otherwise, this is acceptable for test setup.mo_net/tests/model/layer/test_recurrent.py (1)
563-586: Manual cache reset for 2D vs 3D input equivalence test.Line 582 directly resets
layer.cache["input_activations"] = Nonebetween forward passes. This works but couples the test to internal cache structure. If a second layer instance were used instead, the test would be more robust to internal changes.mo_net/tests/samples/word2vec/test_integration.py (1)
113-200: Direct mutation of private attribute_negative_sampling_dist.Line 132 directly sets
model._negative_sampling_dist. This mirrors the same pattern in__main__.py(line 708). Consider exposing a setter or accepting the distribution as acreate()parameter to avoid reaching into internals from test code.mo_net/tests/samples/word2vec/test_hierarchical_integration.py (2)
192-209: Redundant re-imports shadow module-level imports (F811).
HuffmanTree,SoftmaxConfig,SoftmaxStrategy, andHierarchicalSoftmaxOutputLayerare already imported at module level (lines 16–18). The re-imports at lines 198–203 are unnecessary and trigger the F811 lint warning. If the intent is purely to test importability, the module-level imports already prove that.Proposed fix
def test_all_components_importable(self): """Test that all components can be imported.""" - # This test verifies the implementation is complete - from mo_net.samples.word2vec.softmax_strategy import ( - SoftmaxConfig, - SoftmaxStrategy, - ) - from mo_net.data_structures.huffman_tree import HuffmanTree, HuffmanNode - from mo_net.model.layer.output import HierarchicalSoftmaxOutputLayer - assert SoftmaxConfig is not None assert SoftmaxStrategy is not None assert HuffmanTree is not None + from mo_net.data_structures.huffman_tree import HuffmanNode assert HuffmanNode is not None assert HierarchicalSoftmaxOutputLayer is not None
211-222: Placeholder meta-test with hardcoded counts will become stale.
test_all_tests_passingonly assertsTrueand has manually tracked test counts in the docstring. These counts will drift as tests are added or removed. Consider removing this test or generating the count dynamically.mo_net/samples/word2vec/eval.py (2)
76-86: Consider vectorized similarity computation for larger vocabularies.
get_nearest_neighborsiterates over each word in Python, computing cosine similarity one at a time. This is fine for small vocabularies but will be slow for larger ones. A vectorized approach using matrix operations would be significantly faster.Vectorized alternative
- similarities = [] - for other_word in vocab.vocab: - if exclude_self and other_word == word: - continue - other_idx = vocab[other_word] - other_vec = model.embeddings[other_idx] - sim = cosine_similarity(word_vec, other_vec) - similarities.append((other_word, sim)) - - similarities.sort(key=lambda x: x[1], reverse=True) - return similarities[:k] + # Vectorized cosine similarity against all embeddings + all_embeddings = model.embeddings + norms = jnp.linalg.norm(all_embeddings, axis=1, keepdims=True) + 1e-10 + normed = all_embeddings / norms + word_normed = word_vec / (jnp.linalg.norm(word_vec) + 1e-10) + sims = jnp.dot(normed, word_normed) + + # Get top-k+1 to account for possible self-exclusion + top_indices = jnp.argsort(-sims)[: k + (1 if exclude_self else 0)] + + results = [] + for idx in top_indices: + idx_int = int(idx) + if idx_int < len(vocab.vocab): + other_word = list(vocab.vocab)[idx_int] + if exclude_self and other_word == word: + continue + results.append((other_word, float(sims[idx_int]))) + if len(results) >= k: + break + return results
221-221: Unused loop variablei.The loop variable
ifromenumerateis never used. Rename to_per convention.Fix
- for i, result in enumerate(results[:max_display]): + for _i, result in enumerate(results[:max_display]):Or simply drop
enumerate:- for i, result in enumerate(results[:max_display]): + for result in results[:max_display]:mo_net/model/layer/output.py (2)
410-412: Fallback toPRNGKey(0)whenkeyis None may cause reproducibility issues.When
keyis not provided, the layer silently falls back toPRNGKey(0). This could lead to all layers sharing the same initialization seed if multiple are created without explicit keys. Consider raising an error instead, or at least logging a warning.
415-420: Xavier initialization formula may not be appropriate for node vectors.The scale
sqrt(2 / (embedding_dim + num_internal_nodes))uses the number of internal nodes as the fan-out, but each node vector participates in a single binary decision (sigmoid), not a full dense projection. A simpler initialization likesqrt(1 / embedding_dim)or small uniform may be more appropriate for sigmoid-based binary classifiers.mo_net/samples/rnn_language/__main__.py (1)
267-284:result.model_checkpoint_path.rename()may fail across filesystems.
Path.rename()does not work across filesystem boundaries (e.g., tmpfs → persistent storage). Consider usingshutil.move()instead, which handles cross-device moves.mo_net/samples/word2vec/__main__.py (2)
706-708: Repeated pattern of directly mutating_negative_sampling_distprivate attribute.This pattern appears here and in the test file. Consider accepting
negative_sampling_distas a parameter inSkipGramModel.create()to avoid external mutation of a private field.Proposed change in create()
return cls( input_dimensions=(1,), hidden=(...), output=output, key=key, negative_samples=negative_samples, + negative_sampling_dist=neg_sampling_dist, )This would require the caller to pass it at creation time rather than mutating after construction.
812-852: Good addition: saving best checkpoint onTrainingFailed.Preserving the best model checkpoint even when training is stopped early (e.g., rising validation loss) is good practice. The zip packaging logic is consistent with the
TrainingSuccessfulpath.One concern: the zip-packaging block (lines 829–846) is nearly identical to lines 787–803. Consider extracting a helper to avoid this duplication.
| # Special case: single word vocabulary | ||
| if vocab_size == 1: | ||
| # For single word, just return the leaf as root | ||
| # This means num_internal_nodes will be 1 (the root itself) | ||
| leaf = heap[0] | ||
| root = HuffmanNode(frequency=leaf.frequency, left=leaf) | ||
| tree = cls.__new__(cls) | ||
| tree.root = root | ||
| tree.vocab_size = vocab_size | ||
| tree.num_internal_nodes = 1 # One internal node (root) for single word | ||
| tree._paths = {} | ||
| tree._codes = {} | ||
| tree._build_paths() | ||
| return tree |
There was a problem hiding this comment.
Single-word case bypasses __init__ — fragile and causes serialization mismatch.
The cls.__new__(cls) pattern manually mirrors __init__ logic, which is easy to break if __init__ gains new attributes. More importantly, num_internal_nodes is set to 1 here (line 108), but deserialize (line 277) always computes vocab_size - 1, which yields 0 for a single-word tree. This inconsistency can cause downstream bugs (e.g., HierarchicalSoftmaxOutputLayer may allocate the wrong number of node vectors after deserialization).
Consider routing through __init__ and handling the edge case for num_internal_nodes consistently:
Proposed fix: unify construction and fix num_internal_nodes in deserialize
In build, route through normal __init__:
if vocab_size == 1:
leaf = heap[0]
root = HuffmanNode(frequency=leaf.frequency, left=leaf)
- tree = cls.__new__(cls)
- tree.root = root
- tree.vocab_size = vocab_size
- tree.num_internal_nodes = 1
- tree._paths = {}
- tree._codes = {}
- tree._build_paths()
- return tree
+ return cls(root=root, vocab_size=vocab_size)Then fix __init__ to handle the single-word case:
self.vocab_size = vocab_size
- self.num_internal_nodes = vocab_size - 1
+ self.num_internal_nodes = max(vocab_size - 1, 1)And ensure deserialize uses the same formula (it will, once __init__ is used or the formula is unified):
- tree.num_internal_nodes = vocab_size - 1
+ tree.num_internal_nodes = max(vocab_size - 1, 1)🤖 Prompt for AI Agents
In `@mo_net/data_structures/huffman_tree.py` around lines 99 - 112, The
single-word branch in build bypasses __init__ (using cls.__new__) and sets
num_internal_nodes=1, causing serialization/deserialization mismatch; instead
construct the tree via the normal initializer and centralize single-word
handling in __init__ so all attributes (_paths, _codes, root, vocab_size,
num_internal_nodes) are set consistently, and update deserialize to compute
num_internal_nodes as max(0, vocab_size - 1) (or otherwise use the same formula
as __init__) so the single-word case yields num_internal_nodes=0 after
deserialization; reference HuffmanTree.__init__, HuffmanTree.build,
HuffmanTree.deserialize, HuffmanNode, and methods _build_paths/_paths/_codes
when making the change.
| @classmethod | ||
| def deserialize(cls, data: bytes) -> HuffmanTree: | ||
| """Deserialize tree from bytes. | ||
|
|
||
| Args: | ||
| data: Msgpack-encoded tree structure | ||
|
|
||
| Returns: | ||
| Reconstructed HuffmanTree | ||
|
|
||
| Note: | ||
| This recreates the tree structure from saved paths/codes. | ||
| The actual tree structure isn't saved, only the path information needed for inference. | ||
| """ | ||
| decoded = msgpack.unpackb(data, raw=False) | ||
| vocab_size = decoded["vocab_size"] | ||
|
|
||
| # Reconstruct a minimal tree (just needs to support get_path) | ||
| # We don't need the full tree structure for inference | ||
| root = HuffmanNode(frequency=0) # Dummy root | ||
| tree = cls.__new__(cls) | ||
| tree.root = root | ||
| tree.vocab_size = vocab_size | ||
| tree.num_internal_nodes = vocab_size - 1 | ||
|
|
||
| # Restore paths and codes | ||
| tree._paths = {int(k): v for k, v in decoded["paths"].items()} | ||
| tree._codes = {int(k): v for k, v in decoded["codes"].items()} | ||
|
|
There was a problem hiding this comment.
deserialize doesn't serialize/restore num_internal_nodes independently.
serialize() does not include num_internal_nodes in the serialized data, and deserialize() recomputes it as vocab_size - 1. As noted above, this diverges from the single-word tree case where build() sets it to 1. If num_internal_nodes is semantically important (and it is — it's used for vector allocation in hierarchical softmax), consider either:
- Serializing
num_internal_nodesexplicitly, or - Using a single consistent formula everywhere (e.g.,
max(vocab_size - 1, 1)).
🤖 Prompt for AI Agents
In `@mo_net/data_structures/huffman_tree.py` around lines 254 - 282, The
deserialize method in HuffmanTree restores vocab_size but forces
num_internal_nodes = vocab_size - 1 which mismatches the single-word build()
behavior (where build() sets num_internal_nodes = 1) and can break
hierarchical-softmax vector allocation; update the implementation so
num_internal_nodes is preserved or consistently computed: either (A) modify
serialize() to include "num_internal_nodes" and have deserialize() read and
assign it back to tree.num_internal_nodes, or (B) make both build(),
serialize(), and deserialize() compute num_internal_nodes using a single formula
such as max(vocab_size - 1, 1) so the value is consistent across
creation/serialization/deserialization (change code in HuffmanTree.serialize,
HuffmanTree.deserialize, and HuffmanTree.build accordingly).
| def _forward_prop( | ||
| self, | ||
| *, | ||
| input_activations: Activations, | ||
| ) -> Activations: | ||
| """Compute forward pass - returns log probabilities. | ||
|
|
||
| For each sample in batch, computes probability for ALL words by traversing | ||
| all paths in the tree. This is expensive but needed for full softmax during inference. | ||
|
|
||
| Args: | ||
| input_activations: Context embeddings, shape (batch_size, embedding_dim) | ||
|
|
||
| Returns: | ||
| Log probabilities for each word, shape (batch_size, vocab_size) | ||
| """ | ||
| # Cache input for backward pass | ||
| self._cache["input_activations"] = input_activations | ||
|
|
||
| batch_size = input_activations.shape[0] | ||
| log_probs = jnp.zeros((batch_size, self.vocab_size)) | ||
|
|
||
| # For each word, compute log probability by traversing its path | ||
| for word_id in range(self.vocab_size): | ||
| node_indices, directions = self.tree.get_path(word_id) | ||
|
|
||
| if len(node_indices) == 0: | ||
| # Edge case: single word vocabulary | ||
| log_probs = log_probs.at[:, word_id].set(0.0) | ||
| continue | ||
|
|
||
| # Compute log P(word|context) = sum of log P(decision_i) along path | ||
| word_log_prob = jnp.zeros(batch_size) | ||
|
|
||
| for node_idx, go_left in zip(node_indices, directions): | ||
| # Compute score: θ_node^T · h | ||
| scores = jnp.dot( | ||
| input_activations, self._parameters.node_vectors[node_idx] | ||
| ) | ||
|
|
||
| # Log probability of this decision | ||
| if go_left: | ||
| word_log_prob += jax.nn.log_sigmoid(scores) | ||
| else: | ||
| word_log_prob += jax.nn.log_sigmoid(-scores) | ||
|
|
||
| log_probs = log_probs.at[:, word_id].set(word_log_prob) | ||
|
|
||
| # Return probabilities (not log probs) for compatibility | ||
| return Activations(jax.nn.softmax(log_probs)) |
There was a problem hiding this comment.
Forward pass is O(V·log V) per batch — defeats hierarchical softmax's purpose during training.
The forward pass iterates over all V words (line 459) and traverses each word's path in Python loops. This makes the forward pass O(V·log V), which is worse than standard softmax O(V). Hierarchical softmax should only need to compute probabilities for the target words during training (O(log V) per sample).
Additionally, line 485 applies jax.nn.softmax to the already-computed log-probabilities. The Huffman tree paths define a proper probability distribution by construction (each internal node splits probability mass), so exp(log_probs) should yield valid probabilities summing to ~1. Applying softmax renormalizes them, which masks any tree construction issues and changes the gradient signal.
Consider:
- A training-time forward pass that only computes
log P(target_word | context)— O(log V) per sample. - Using
jnp.exp(log_probs)instead ofjax.nn.softmax(log_probs)for the inference-time full-vocabulary pass, or at least documenting why the renormalization is intentional.
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 470-470: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
| def _backward_prop( | ||
| self, | ||
| *, | ||
| Y_true: jnp.ndarray, | ||
| ) -> D[Activations]: | ||
| """Compute backward pass - sparse gradients only for paths to target words. | ||
|
|
||
| Args: | ||
| Y_true: Target word indices, shape (batch_size,) | ||
|
|
||
| Returns: | ||
| Gradient w.r.t. input activations, shape (batch_size, embedding_dim) | ||
| """ | ||
| if (input_activations := self._cache["input_activations"]) is None: | ||
| raise ValueError("Input activations not set during forward pass.") | ||
|
|
||
| batch_size, embedding_dim = input_activations.shape | ||
|
|
||
| # Initialize gradients | ||
| grad_input = jnp.zeros_like(input_activations) | ||
| grad_node_vectors = jnp.zeros_like(self._parameters.node_vectors) | ||
|
|
||
| # Process each sample in batch | ||
| for i in range(batch_size): | ||
| target_word = int(Y_true[i]) | ||
| h = input_activations[i] # Context embedding for this sample | ||
|
|
||
| # Get path to target word | ||
| node_indices, directions = self.tree.get_path(target_word) | ||
|
|
||
| # For each node on the path, compute gradient | ||
| for node_idx, go_left in zip(node_indices, directions): | ||
| # Forward: score = θ_node^T · h | ||
| theta = self._parameters.node_vectors[node_idx] | ||
| score = jnp.dot(theta, h) | ||
|
|
||
| # Sigmoid and gradient | ||
| sig = jax.nn.sigmoid(score) | ||
|
|
||
| # Target: 1 if go_left, 0 if go_right | ||
| target = 1.0 if go_left else 0.0 | ||
|
|
||
| # Gradient of negative log likelihood (loss): (sigmoid - target) | ||
| # This is for gradient descent: we want to minimize -log P(path) | ||
| grad = sig - target | ||
|
|
||
| # Gradient w.r.t. input: ∂loss/∂h = (σ - target) · θ | ||
| grad_input = grad_input.at[i].add(grad * theta) | ||
|
|
||
| # Gradient w.r.t. node vector: ∂loss/∂θ = (σ - target) · h | ||
| grad_node_vectors = grad_node_vectors.at[node_idx].add(grad * h) | ||
|
|
||
| # Store gradient in cache for optimizer | ||
| self._cache["dP"] = d(self.Parameters(node_vectors=grad_node_vectors)) | ||
|
|
||
| return cast(D[Activations], grad_input) |
There was a problem hiding this comment.
Backward pass uses Python loops over batch and path nodes — not JIT-compatible and slow.
The nested Python for loops (lines 510 and 518) over batch_size and path nodes prevent JAX from tracing/JIT-compiling this code efficiently. For training at any meaningful scale, this will be a severe bottleneck.
Also, line 503 unpacks embedding_dim but never uses it.
Fix for unused variable
- batch_size, embedding_dim = input_activations.shape
+ batch_size, _embedding_dim = input_activations.shape🧰 Tools
🪛 Ruff (0.14.14)
[warning] 501-501: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 503-503: Unpacked variable embedding_dim is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 518-518: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
[warning] 533-533: Comment contains ambiguous σ (GREEK SMALL LETTER SIGMA). Did you mean o (LATIN SMALL LETTER O)?
(RUF003)
[warning] 536-536: Comment contains ambiguous σ (GREEK SMALL LETTER SIGMA). Did you mean o (LATIN SMALL LETTER O)?
(RUF003)
🤖 Prompt for AI Agents
In `@mo_net/model/layer/output.py` around lines 487 - 542, The _backward_prop
method currently uses Python for-loops over batch and path nodes (in
_backward_prop and calls to self.tree.get_path), which prevents JAX tracing/JIT;
also embedding_dim is unused — fix by making the path data JAX-compatible and
vectorizing the gradient computation: change tree.get_path to return batched JAX
arrays (node_indices: int32[batch, max_path], directions: int8[batch, max_path],
and path_lengths) or add a helper that converts per-sample Python paths into
padded arrays, then compute per-node scores and sigmoids with
jnp.take(self._parameters.node_vectors, node_indices, axis=0) and jnp.einsum or
jnp.sum to get scores for all batch×path entries, compute grad = sigmoid -
directions (cast) as a jnp array, then accumulate grad_input via
jnp.einsum('bpv,bp->bv', node_vectors_selected, grad) and grad_node_vectors via
jax.ops.segment_sum or jnp.zeros_like + jax.ops.index_add (or use jax.vmap over
samples with lax.scan over path dimension) to produce grad_node_vectors; store
dP with self._cache["dP"]=d(self.Parameters(node_vectors=grad_node_vectors)),
and remove the unused embedding_dim variable.
| def __sub__(self, other: Self | float | int) -> Self: | ||
| return self.__add__(-other) | ||
|
|
||
| def __rsub__(self, other: Self | float | int) -> Self: | ||
| return self.__sub__(other) |
There was a problem hiding this comment.
__rsub__ is incorrect — other - self returns self - other.
__rsub__ is called when Python evaluates other - self and other doesn't know how to subtract a Parameters. The current implementation delegates to self.__sub__(other), which computes self - other — the opposite sign. For example, 0 - params would return params instead of -params.
Proposed fix
def __rsub__(self, other: Self | float | int) -> Self:
- return self.__sub__(other)
+ return (-self).__add__(other)🤖 Prompt for AI Agents
In `@mo_net/model/layer/recurrent.py` around lines 67 - 71, The __rsub__
implementation is backwards: when Python computes other - self it should return
other - self, not self - other; update __rsub__ in the class containing
__sub__/__add__ so it negates self then adds other (e.g. return
(-self).__add__(other) or equivalently return (-self) + other) ensuring the
operation uses the same __add__/negation semantics as __sub__/__add__ for
Parameters.
| trainer = BasicTrainer( | ||
| X_train=X_train[:train_size], | ||
| X_val=X_train[train_size:], | ||
| Y_train=Y_train[:train_size].flatten(), | ||
| Y_val=Y_train[train_size:].flatten(), |
There was a problem hiding this comment.
Shape mismatch between X_train and Y_train.flatten() will break batching.
X_train[:train_size] has shape (train_size, sequence_length), while Y_train[:train_size].flatten() has shape (train_size * sequence_length,). When BasicTrainer batches both arrays by the same row indices, the batch dimensions are misaligned — X-batch will have batch_size rows but Y-batch indexing will be off by a factor of sequence_length.
The same issue exists in train_pretrained at lines 468–469.
🤖 Prompt for AI Agents
In `@mo_net/samples/rnn_language/__main__.py` around lines 250 - 254, The Y arrays
were flattened which collapses the sequence dimension and breaks batching for
BasicTrainer; instead keep Y_train and Y_val with the same first-dimension as X
(do not call .flatten()) or reshape them to (train_size, sequence_length) so
their row count matches X_train[:train_size]; update the two instantiations (the
BasicTrainer call around the listed lines and the train_pretrained call
referenced) to pass Y_train[:train_size] and Y_train[train_size:] (or
appropriately reshaped/squeezed forms) so batch indexing aligns with X_train.
| key = jax.random.PRNGKey(int(time.time())) | ||
|
|
||
| for _ in range(num_words): | ||
| X = jnp.array([current_tokens], dtype=jnp.int32) | ||
| logits = model.forward_prop(X) | ||
| pred_reshaped = logits.reshape(sequence_length, vocab_size) | ||
| last_logits = pred_reshaped[-1] | ||
|
|
||
| if temperature != 1.0: | ||
| last_logits = last_logits / temperature | ||
|
|
||
| key, subkey = jax.random.split(key) | ||
| probs = jax.nn.softmax(last_logits) | ||
| next_token = int(jax.random.choice(subkey, vocab_size, p=probs)) |
There was a problem hiding this comment.
Double softmax produces incorrect sampling probabilities.
model.forward_prop(X) already applies softmax (via SparseCategoricalSoftmaxOutputLayer). Line 578 applies jax.nn.softmax again on the extracted slice, resulting in a doubly-squashed distribution. Temperature scaling (line 575) is also applied to post-softmax values, which doesn't produce the intended effect — temperature should be applied to pre-softmax logits.
Minimal fix (assuming architecture is corrected separately)
last_logits = pred_reshaped[-1]
- if temperature != 1.0:
- last_logits = last_logits / temperature
-
key, subkey = jax.random.split(key)
- probs = jax.nn.softmax(last_logits)
+ probs = last_logits / jnp.sum(last_logits) # already softmaxed, just renormalize slice
next_token = int(jax.random.choice(subkey, vocab_size, p=probs))Note: Temperature scaling requires access to pre-softmax logits, which needs an architecture change.
🤖 Prompt for AI Agents
In `@mo_net/samples/rnn_language/__main__.py` around lines 566 - 579, The sampling
code double-applies softmax and applies temperature after softmax; fix by using
pre-softmax logits and applying temperature before softmax (or, if the model
currently returns probabilities from model.forward_prop because of
SparseCategoricalSoftmaxOutputLayer, stop re-softmaxing and sample directly from
those probabilities). Specifically, change model.forward_prop to return raw
logits (or add a flag/alternative method that returns logits) so you can compute
last_logits = logits.reshape(sequence_length, vocab_size)[-1], then if
temperature != 1.0 do last_logits = last_logits / temperature, compute probs =
jax.nn.softmax(last_logits), and finally sample with jax.random.choice(subkey,
vocab_size, p=probs); alternatively, if you cannot change the model, remove the
jax.nn.softmax call and sample from the probabilities returned by
model.forward_prop without temperature scaling.
| results = [ | ||
| result | ||
| for example in analogies | ||
| if ( | ||
| result := __import__( | ||
| "mo_net.samples.word2vec.eval", fromlist=["evaluate_analogy"] | ||
| ).evaluate_analogy(example, model, vocab, top_k) | ||
| ) | ||
| ] |
There was a problem hiding this comment.
Bug: __import__ used instead of the already-imported evaluate_analogy.
evaluate_analogy is imported at lines 1084–1089 but unused — instead, lines 1102–1104 use a dynamic __import__ call. This is unnecessarily complex, harder to read, and bypasses the explicit import. Additionally, evaluate_analogies is called again at line 1110, re-evaluating all examples a second time.
Proposed fix
+ from mo_net.samples.word2vec.eval import (
+ evaluate_analogy,
+ evaluate_analogies,
+ evaluate_model,
+ get_default_analogies,
+ print_analogy_results,
+ )
+
model, vocab = load_model_and_vocab(model_path)
click.echo("Evaluating word2vec model...")
click.echo(f"Vocabulary size: {len(vocab.vocab)}")
click.echo()
analogies = get_default_analogies()
- results = [
- result
- for example in analogies
- if (
- result := __import__(
- "mo_net.samples.word2vec.eval", fromlist=["evaluate_analogy"]
- ).evaluate_analogy(example, model, vocab, top_k)
- )
- ]
+ results = [evaluate_analogy(ex, model, vocab, top_k) for ex in analogies]
print_analogy_results(results)
- metrics = evaluate_analogies(analogies, model, vocab, top_k)
+ # Compute metrics from already-evaluated results instead of re-evaluating
+ total = len(results)
+ correct = sum(1 for r in results if r.correct)
+ metrics = {"overall_accuracy": correct / total if total else 0.0, "total_examples": total, "correct_predictions": correct}
click.echo("\n=== Evaluation Metrics ===")
for key, value in metrics.items():
if isinstance(value, float):
click.echo(f"{key}: {value:.2%}")
else:
click.echo(f"{key}: {value}")
-
- # Also run comprehensive evaluation
- click.echo()
- evaluate_model(model, vocab, analogies)🤖 Prompt for AI Agents
In `@mo_net/samples/word2vec/__main__.py` around lines 1098 - 1106, Replace the
dynamic __import__ usage inside the list comprehension with the already-imported
function evaluate_analogy: call evaluate_analogy(example, model, vocab, top_k)
directly when building results for analogies; ensure the list comprehension uses
the imported evaluate_analogy symbol (not __import__) and preserves the
conditional filtering for truthy results, and remove or stop the subsequent
duplicate call to evaluate_analogies that re-evaluates examples so the examples
are evaluated only once and results is reused.
| def test_serialize_single_word_tree(self): | ||
| """Test serializing tree with single word.""" | ||
| frequencies = {0: 100} | ||
| tree = HuffmanTree.build(frequencies) | ||
|
|
||
| data = tree.serialize() | ||
| tree2 = HuffmanTree.deserialize(data) | ||
|
|
||
| assert tree2.vocab_size == 1 | ||
| path, code = tree2.get_path(0) | ||
| assert len(path) == len(code) |
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
Consider asserting num_internal_nodes consistency in single-word serialization test.
This test doesn't verify that tree2.num_internal_nodes matches tree.num_internal_nodes after the round-trip — which is where the deserialize bug manifests. Adding this assertion would catch the mismatch.
Proposed addition
assert tree2.vocab_size == 1
+ assert tree2.num_internal_nodes == tree.num_internal_nodes
path, code = tree2.get_path(0)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_serialize_single_word_tree(self): | |
| """Test serializing tree with single word.""" | |
| frequencies = {0: 100} | |
| tree = HuffmanTree.build(frequencies) | |
| data = tree.serialize() | |
| tree2 = HuffmanTree.deserialize(data) | |
| assert tree2.vocab_size == 1 | |
| path, code = tree2.get_path(0) | |
| assert len(path) == len(code) | |
| def test_serialize_single_word_tree(self): | |
| """Test serializing tree with single word.""" | |
| frequencies = {0: 100} | |
| tree = HuffmanTree.build(frequencies) | |
| data = tree.serialize() | |
| tree2 = HuffmanTree.deserialize(data) | |
| assert tree2.vocab_size == 1 | |
| assert tree2.num_internal_nodes == tree.num_internal_nodes | |
| path, code = tree2.get_path(0) | |
| assert len(path) == len(code) |
🤖 Prompt for AI Agents
In `@mo_net/tests/data_structures/test_huffman_tree.py` around lines 261 - 271,
Update the test_serialize_single_word_tree to assert that num_internal_nodes is
preserved across serialization round-trip: after creating tree =
HuffmanTree.build(frequencies), serializing via HuffmanTree.serialize and
deserializing via HuffmanTree.deserialize into tree2, add an assertion that
tree2.num_internal_nodes == tree.num_internal_nodes in addition to the existing
checks (vocab_size and get_path). This will catch the deserialize bug where
num_internal_nodes is not reconstructed correctly.
| # TODO: Fix gradient transfer test - batch_norm case fails | ||
| @pytest.mark.skip(reason="TODO: batch_norm gradient transfer broken") |
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
Skip is too broad — it disables the linear and convolution cases that presumably still pass.
The reason cites only batch_norm, but @pytest.mark.skip blankets all three parameterized cases. Scope the skip to the failing case only so the other two continue to provide regression coverage.
Proposed fix: move the skip marker onto the batch_norm param only
GradientTransferTestCase(
name="batch_norm",
- model=Model(
+ ...
+ ) # wrap in pytest.param with a skip mark:
+ pytest.param(
+ GradientTransferTestCase(
+ name="batch_norm",
+ ...
+ ),
+ marks=pytest.mark.skip(reason="TODO: batch_norm gradient transfer broken"),
),Concretely, replace the bare GradientTransferTestCase(name="batch_norm", ...) entry with:
pytest.param(
GradientTransferTestCase(
name="batch_norm",
model=Model(
input_dimensions=(4,),
hidden=[BatchNorm(input_dimensions=(4,), training=True)],
),
forward_input=jax.random.uniform(jax.random.split(key)[0], (3, 4)),
backward_input=jax.random.uniform(jax.random.split(key)[1], (3, 4)),
expected_w_shape=(4,),
expected_b_shape=(4,),
),
marks=pytest.mark.skip(reason="TODO: batch_norm gradient transfer broken"),
),And remove the top-level @pytest.mark.skip(...) decorator + TODO comment from the function.
🤖 Prompt for AI Agents
In `@mo_net/tests/trainer/test_parallel.py` around lines 99 - 100, The top-level
pytest.skip is too broad and skips all parameterized cases; narrow the skip to
only the failing GradientTransferTestCase for "batch_norm". Replace the bare
GradientTransferTestCase(name="batch_norm", ...) entry in the parameter list
with a pytest.param(...) wrapper that includes
marks=pytest.mark.skip(reason="TODO: batch_norm gradient transfer broken") and
remove the top-level `@pytest.mark.skip` decorator and its TODO comment from the
test function; locate the entries using GradientTransferTestCase, Model and
BatchNorm to modify the specific parameter and use pytest.param and
pytest.mark.skip to scope the skip.
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (11)
mo_net/cli.py (1)
280-330: Use an exhaustivematchto avoid a silent type-lie in theelsebranch.
castis a no-op at runtime. Theelsebranch currently assertsnormalisation_typeisLiteral[NormalisationType.NONE, NormalisationType.LAYER], but this will silently mistype any futureNormalisationTypemember that isn'tBATCH. An exhaustivematchwithassert_nevermakes this compiler-checked.♻️ Proposed refactor: exhaustive match + deduplicated kwargs
- if normalisation_type == NormalisationType.BATCH: - # Batch normalisation requires batch_size - return Model.mlp_of( - key=key, - module_dimensions=( - tuple( - map( - lambda d: (d,), - [ - X_train.shape[1], - *dims, - Y_train.shape[ - 1 - ], # Output dimension should match number of classes - ], - ) - ) - ), - activation_fn=activation_fn, - batch_size=batch_size, - normalisation_type=cast( - Literal[NormalisationType.BATCH], normalisation_type - ), - tracing_enabled=tracing_enabled, - dropout_keep_probs=dropout_keep_probs, - ) - else: - return Model.mlp_of( - key=key, - module_dimensions=( - tuple( - map( - lambda d: (d,), - [ - X_train.shape[1], - *dims, - Y_train.shape[ - 1 - ], # Output dimension should match number of classes - ], - ) - ) - ), - activation_fn=activation_fn, - normalisation_type=cast( - Literal[NormalisationType.NONE, NormalisationType.LAYER], - normalisation_type, - ), - tracing_enabled=tracing_enabled, - dropout_keep_probs=dropout_keep_probs, - ) + module_dimensions = tuple( + map( + lambda d: (d,), + [ + X_train.shape[1], + *dims, + Y_train.shape[1], # Output dimension should match number of classes + ], + ) + ) + common_kwargs = dict( + key=key, + module_dimensions=module_dimensions, + activation_fn=activation_fn, + tracing_enabled=tracing_enabled, + dropout_keep_probs=dropout_keep_probs, + ) + match normalisation_type: + case NormalisationType.BATCH: + return Model.mlp_of( + **common_kwargs, + batch_size=batch_size, + normalisation_type=cast( + Literal[NormalisationType.BATCH], normalisation_type + ), + ) + case NormalisationType.NONE | NormalisationType.LAYER: + return Model.mlp_of( + **common_kwargs, + normalisation_type=cast( + Literal[NormalisationType.NONE, NormalisationType.LAYER], + normalisation_type, + ), + ) + case _: + assert_never(normalisation_type)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/cli.py` around lines 280 - 330, Replace the current if/else on normalisation_type in the Model.mlp_of call with an exhaustive match on NormalisationType to ensure new enum members don't silently get cast; handle the NormalisationType.BATCH arm by passing batch_size and setting normalisation_type to NormalisationType.BATCH, handle the NormalisationType.NONE and NormalisationType.LAYER arms by passing the appropriate normalisation_type values, and add an assert_never(default) fallback for compile-time exhaustiveness checking; also deduplicate common kwargs (key, module_dimensions, activation_fn, tracing_enabled, dropout_keep_probs) into a shared kwargs dict used by Model.mlp_of so only the per-branch differences (batch_size and normalisation_type) are set in each match arm.mo_net/train/backends/log.py (4)
310-330: Significant duplication withSqliteBackend.
start_run,end_run, and the run-field-update logic inlog_iterationare near-duplicates ofSqliteBackend. If these two backends evolve together (which the matchingend_runchange in this PR already suggests), keeping them in sync will become error-prone. A shared base class or mixin that owns the common session/run lifecycle could reduce this surface.Not urgent, but worth considering as a follow-up.
Also applies to: 332-341, 351-382
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/train/backends/log.py` around lines 310 - 330, The start_run, end_run and run-update logic in log_iteration in this backend duplicates code in SqliteBackend (e.g., methods start_run, end_run, log_iteration and use of DbRun, _session, _current_run), so extract the shared lifecycle into a new base class or mixin (e.g., RunLifecycleMixin) that implements create/commit of DbRun and updating run fields, then have both backends inherit that mixin and remove the duplicated implementations; ensure the mixin exposes the same method signatures (start_run, end_run, update_run_from_iteration) and uses the existing _session and _current_run names so callers in LogBackend and SqliteBackend need minimal changes.
348-349:log_training_parameterssilently discards data.This is a no-op, which means training parameters are lost when using the in-memory backend. If this backend is intended for testing, tests that verify training-parameter logging will silently pass without actually storing anything. Consider storing parameters on the
DbRunmodel or in a local attribute so they can be asserted against in tests.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/train/backends/log.py` around lines 348 - 349, The in-memory backend's log_training_parameters method is a no-op and discards the training_parameters; implement it to persist the string so tests can assert on it (e.g., attach the value to the corresponding DbRun instance or a local dict keyed by run id). Update the log_training_parameters function to accept the run context (or look up the current run) and store training_parameters on the DbRun model (or an in-memory attribute like self._runs[run_id].training_parameters or self._training_parameters_by_run[run_id]) so subsequent reads (e.g., get_run or tests) return the stored parameters.
343-346: Consider flushing uncommitted iteration data on teardown.
log_iterationcommits after every call, so in the normal flow there's nothing pending. However, if a future change introduces batching (mirroringSqliteBackend), or if an exception interrupts the commit insidelog_iteration, the session could hold unflushed data that would be silently lost here. A defensiveself._session.commit()beforeclose()(or at minimum a comment explaining why it's unnecessary) would make the contract clearer.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/train/backends/log.py` around lines 343 - 346, In teardown() ensure any pending transaction is persisted before closing the SQLAlchemy session: if self._session is set, call self._session.commit() (catch/handle exceptions if needed) and then close and null out self._session; reference teardown(), self._session, and log_iteration (and the SqliteBackend batching behavior) so the change protects against future batching or interrupted commits and makes intent explicit.
293-383:InMemorySqliteBackendis not reachable fromparse_connection_string.The factory at line 428 doesn't recognise an in-memory connection string, so this backend can only be used via direct construction. If that's intentional (test-only usage), a brief docstring note would help future maintainers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/train/backends/log.py` around lines 293 - 383, The InMemorySqliteBackend isn't returned by parse_connection_string because its connection_string ("sqlite://:memory:") doesn't match the in-memory pattern parse_connection_string checks (and the engine in __init__ uses "sqlite:///:memory:"), so either update parse_connection_string to recognise the in-memory SQLite URIs (accept "sqlite:///:memory:" and "sqlite://:memory:") and return InMemorySqliteBackend, or make the backend's connection_string match the pattern parse_connection_string expects; update the InMemorySqliteBackend docstring to clarify if this backend is intentionally only for direct construction/testing.mo_net/tests/samples/word2vec/test_hierarchical_integration.py (3)
74-76: Move theActivationsimport to module level.The
from mo_net.protos import Activationsimport on line 74 is inside the test method. Moving it to the top of the file with other imports improves readability and avoids repeated import overhead.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/tests/samples/word2vec/test_hierarchical_integration.py` around lines 74 - 76, Move the "from mo_net.protos import Activations" import out of the test body and place it with the module-level imports at the top of the file so it's executed once on import rather than inside the test; update the test that calls output_layer._forward_prop(input_activations=Activations(h)) to rely on the module-level Activations symbol (no other code changes needed).
197-211: Redundant re-imports shadow module-level names.Lines 200-205 re-import
SoftmaxConfig,SoftmaxStrategy,HuffmanTree, andHuffmanNodethat are already imported at the module level (lines 18-19). This triggers Ruff F811 forHuffmanTree. If the intent is to verify importability, the module-level imports already prove that; this test adds no value.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/tests/samples/word2vec/test_hierarchical_integration.py` around lines 197 - 211, The test re-imports symbols already imported at module scope (SoftmaxConfig, SoftmaxStrategy, HuffmanTree, HuffmanNode, HierarchicalSoftmaxOutputLayer) which shadows module-level names and triggers a Ruff F811; remove the redundant local from-imports in test_all_components_importable and instead assert the existing module-level symbols (e.g., SoftmaxConfig, SoftmaxStrategy, HuffmanTree, HuffmanNode, HierarchicalSoftmaxOutputLayer) are not None, or if you want local imports for clarity, rename them to avoid shadowing (e.g., import as SoftmaxConfigLocal) and use those names consistently.
213-224:test_all_tests_passingis a no-op placeholder — consider removing.A test that simply
assert Trueprovides no verification value. The comment documenting test counts will become stale quickly.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/tests/samples/word2vec/test_hierarchical_integration.py` around lines 213 - 224, Remove the no-op placeholder test function test_all_tests_passing from the file: it only contains assert True and stale documentation, so delete the function (def test_all_tests_passing(...)) from mo_net/tests/samples/word2vec/test_hierarchical_integration.py; if you want an aggregate check, replace it with a concrete test that programmatically inspects expected test outcomes (e.g., using pytest APIs) instead of a static assert/documentation comment.mo_net/model/layer/recurrent.py (1)
270-271: Useless conditional — both branches assign the same value.output_dimensions = hidden_dimensions if return_sequences else hidden_dimensionsBoth sides of the ternary are identical. Simplify to a direct assignment.
Proposed fix
- # Output dimensions depend on return_sequences - output_dimensions = hidden_dimensions if return_sequences else hidden_dimensions + output_dimensions = hidden_dimensions🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/model/layer/recurrent.py` around lines 270 - 271, The ternary assigning output_dimensions uses identical branches; replace the conditional with a direct assignment: set output_dimensions equal to hidden_dimensions (remove the "if return_sequences else" expression). Update the occurrence in the recurrent layer implementation where variables output_dimensions, hidden_dimensions, and return_sequences are used so the code becomes a simple assignment and behavior is unchanged.mo_net/model/layer/output.py (1)
335-341:CacheTypedDict couples optimizer state (first_moment,second_moment) into the layer.Storing optimizer-specific fields (
first_moment,second_moment) in the layer's cache tightly couples the layer to the Adam/RMSProp optimizer internals. If a different optimizer is used, these fields are wasted. Other layers (e.g.,Linear,Recurrent) don't store optimizer moments in their cache — the optimizers manage that state externally.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/model/layer/output.py` around lines 335 - 341, The Cache TypedDict currently couples optimizer state into the layer by including first_moment and second_moment; remove those optimizer-specific keys from the Cache definition (leave output_activations, input_activations, and dP related to HierarchicalSoftmaxOutputLayer.Parameters) and refactor any usages of Cache.first_moment/Cache.second_moment to use the optimizer's external state management (i.e., move moment storage/access into the optimizer implementation that indexes parameters by HierarchicalSoftmaxOutputLayer.Parameters or the D[...] key), ensuring references to Cache, first_moment, second_moment, and HierarchicalSoftmaxOutputLayer.Parameters are updated accordingly.mo_net/model/model.py (1)
534-555: Gradient cache handling for GradLayer output layers looks correct.The
get_gradient_cachesandpopulate_cachesmethods properly account for the optional gradient-capable output layer. The length validation inpopulate_cachesensures updates and layers stay in sync.One minor improvement: add
strict=Trueto thezipon line 551 for extra safety, per Ruff B905.Proposed fix
- for layer, update in zip(self.grad_layers, updates[: len(self.grad_layers)]): + for layer, update in zip(self.grad_layers, updates[: len(self.grad_layers)], strict=True):🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mo_net/model/model.py` around lines 534 - 555, In populate_caches, make the pairing of grad_layers and updates strict to catch length mismatches by changing the loop in populate_caches that currently uses zip(self.grad_layers, updates[: len(self.grad_layers)]) to zip(self.grad_layers, updates[: len(self.grad_layers)], strict=True); keep the existing length check and the final assignment to output_layer.cache["dP"] unchanged (referencing methods get_gradient_caches and populate_caches and symbol self.grad_layers and output_layer.cache["dP"]).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@mo_net/model/layer/output.py`:
- Around line 410-413: The code silently falls back to jax.random.PRNGKey(0)
when key is None during node-vector initialization, harming reproducibility;
instead, remove the silent fallback and raise a clear error (ValueError) when
key is None, telling the caller to supply an explicit PRNGKey (or an explicit
seed/flag when deserializing); update the block that currently checks "if key is
None: key = jax.random.PRNGKey(0)" (referencing the variable key and the
jax.random.PRNGKey(0) usage in the "Initialize node vectors (one per internal
node)" section) to validate key and raise with a descriptive message so callers
must opt in to any default.
In `@mo_net/model/layer/recurrent.py`:
- Around line 445-479: The bptt_step JIT-traced function contains an unused
variable h_t = hidden_states[:, t + 1, :], causing an unnecessary dynamic slice;
remove the h_t assignment (and any mentions) from bptt_step so the code no
longer performs that unused slice inside jax.lax.scan, leaving calculations that
reference hidden_states and input_activations (e.g., dh_t, h_prev, x_t, dh_raw,
dW_ih_t, dW_hh_t, dB_t, dx_t, dh_prev) unchanged.
- Around line 369-377: Remove the `@jit` decorator from the nested rnn_step in
_forward_prop and stop closing over mutable self state: capture the needed
tensors (weights_ih, weights_hh, biases, activation_fn) into local variables
before defining rnn_step (or pass them as explicit args to the scan body) so the
function identity is stable across forward passes; do the same refactor for
bptt_step and analogous nested functions in linear.py to avoid repeated JIT
retracing and rely on jax.lax.scan's compilation instead.
In `@mo_net/train/backends/log.py`:
- Around line 302-304: The connection_string property currently returns
"sqlite://:memory:" which is inconsistent with the engine URL created using
"sqlite:///:memory:" (three slashes) and causes incorrect URL parsing; update
the connection_string property (method connection_string) to return the exact
same URL used to create the engine (three slashes, "sqlite:///:memory:") so
round-tripping via create_engine()/urlparse() yields the correct in-memory
SQLite URL.
---
Duplicate comments:
In `@mo_net/data_structures/huffman_tree.py`:
- Around line 254-283: The deserialize/serialize inconsistency: adjust
HuffmanTree so num_internal_nodes is preserved and round-trips correctly —
either include num_internal_nodes in serialize() and read it back in
deserialize(), or change both build() and deserialize() to compute
num_internal_nodes using a single consistent formula (e.g., set
num_internal_nodes = max(vocab_size - 1, 1)); update serialize(), deserialize(),
and build() (referencing methods serialize, deserialize, build and the attribute
num_internal_nodes and vocab_size) to follow the chosen approach so single-word
trees remain num_internal_nodes == 1 and saved trees restore exactly.
- Around line 99-112: The single-word branch in build() should not bypass
__init__; instead construct the HuffmanTree via its normal initializer and set
attributes consistently: create the leaf node (HuffmanNode) and pass it to the
class initializer or call __init__ so tree.root, tree.vocab_size and internal
counters are initialized the same way as other cases; change the single-word
num_internal_nodes logic to use the same formula used in deserialize
(num_internal_nodes = max(0, vocab_size - 1)) so serialize/deserialize
round-trips produce the same count; ensure _build_paths() is still invoked and
that deserialize() retains the identical formula for num_internal_nodes to avoid
allocation mismatch in HierarchicalSoftmaxOutputLayer.
In `@mo_net/model/layer/output.py`:
- Around line 488-543: The _backward_prop method uses Python for-loops over
batch and path which prevents JAX tracing; replace the nested loops in
_backward_prop by vectorizing: obtain fixed-shape padded node_indices and
directions for the batch (use tree.get_path mapped to produce padded arrays and
a mask), use jnp.take(self._parameters.node_vectors, node_indices, axis=0) and
jnp.einsum or jnp.sum with the mask to compute per-path scores, sigmoids and
(sig - target) weights, then compute grad_input and grad_node_vectors via
batched ops (e.g., einsum or jnp.matmul) and accumulate across path positions;
also rename unused embedding_dim to _embedding_dim. Keep storing
self._cache["dP"] = d(self.Parameters(node_vectors=grad_node_vectors)) and
return grad_input as before.
In `@mo_net/model/layer/recurrent.py`:
- Around line 67-71: The __rsub__ implementation is wrong: it calls
self.__sub__(other) which computes self - other instead of other - self; change
__rsub__ to compute other + (-self) by delegating to the negation of self (use
__neg__ or existing negation operator) and then adding other (i.e., return
(-self).__add__(other) or equivalent) so expressions like 0 - params produce the
correct negated result; update the __rsub__ method in the class that defines
__sub__ and __neg__ accordingly.
- Around line 239-253: The deserializer must allow recreating an
already-registered layer: update Recurrent.deserialize to pass exists_ok=True
when constructing Recurrent, and update the Recurrent class constructor
(Recurrent.__init__) to accept an exists_ok: bool = False parameter and forward
it to the base initializer (call super().__init__(..., exists_ok=exists_ok) or
pass to _Base.__init__) so deserialization of pre-registered layers succeeds;
keep parameter name exists_ok to match Linear's pattern.
---
Nitpick comments:
In `@mo_net/cli.py`:
- Around line 280-330: Replace the current if/else on normalisation_type in the
Model.mlp_of call with an exhaustive match on NormalisationType to ensure new
enum members don't silently get cast; handle the NormalisationType.BATCH arm by
passing batch_size and setting normalisation_type to NormalisationType.BATCH,
handle the NormalisationType.NONE and NormalisationType.LAYER arms by passing
the appropriate normalisation_type values, and add an assert_never(default)
fallback for compile-time exhaustiveness checking; also deduplicate common
kwargs (key, module_dimensions, activation_fn, tracing_enabled,
dropout_keep_probs) into a shared kwargs dict used by Model.mlp_of so only the
per-branch differences (batch_size and normalisation_type) are set in each match
arm.
In `@mo_net/model/layer/output.py`:
- Around line 335-341: The Cache TypedDict currently couples optimizer state
into the layer by including first_moment and second_moment; remove those
optimizer-specific keys from the Cache definition (leave output_activations,
input_activations, and dP related to HierarchicalSoftmaxOutputLayer.Parameters)
and refactor any usages of Cache.first_moment/Cache.second_moment to use the
optimizer's external state management (i.e., move moment storage/access into the
optimizer implementation that indexes parameters by
HierarchicalSoftmaxOutputLayer.Parameters or the D[...] key), ensuring
references to Cache, first_moment, second_moment, and
HierarchicalSoftmaxOutputLayer.Parameters are updated accordingly.
In `@mo_net/model/layer/recurrent.py`:
- Around line 270-271: The ternary assigning output_dimensions uses identical
branches; replace the conditional with a direct assignment: set
output_dimensions equal to hidden_dimensions (remove the "if return_sequences
else" expression). Update the occurrence in the recurrent layer implementation
where variables output_dimensions, hidden_dimensions, and return_sequences are
used so the code becomes a simple assignment and behavior is unchanged.
In `@mo_net/model/model.py`:
- Around line 534-555: In populate_caches, make the pairing of grad_layers and
updates strict to catch length mismatches by changing the loop in
populate_caches that currently uses zip(self.grad_layers, updates[:
len(self.grad_layers)]) to zip(self.grad_layers, updates[:
len(self.grad_layers)], strict=True); keep the existing length check and the
final assignment to output_layer.cache["dP"] unchanged (referencing methods
get_gradient_caches and populate_caches and symbol self.grad_layers and
output_layer.cache["dP"]).
In `@mo_net/tests/samples/word2vec/test_hierarchical_integration.py`:
- Around line 74-76: Move the "from mo_net.protos import Activations" import out
of the test body and place it with the module-level imports at the top of the
file so it's executed once on import rather than inside the test; update the
test that calls output_layer._forward_prop(input_activations=Activations(h)) to
rely on the module-level Activations symbol (no other code changes needed).
- Around line 197-211: The test re-imports symbols already imported at module
scope (SoftmaxConfig, SoftmaxStrategy, HuffmanTree, HuffmanNode,
HierarchicalSoftmaxOutputLayer) which shadows module-level names and triggers a
Ruff F811; remove the redundant local from-imports in
test_all_components_importable and instead assert the existing module-level
symbols (e.g., SoftmaxConfig, SoftmaxStrategy, HuffmanTree, HuffmanNode,
HierarchicalSoftmaxOutputLayer) are not None, or if you want local imports for
clarity, rename them to avoid shadowing (e.g., import as SoftmaxConfigLocal) and
use those names consistently.
- Around line 213-224: Remove the no-op placeholder test function
test_all_tests_passing from the file: it only contains assert True and stale
documentation, so delete the function (def test_all_tests_passing(...)) from
mo_net/tests/samples/word2vec/test_hierarchical_integration.py; if you want an
aggregate check, replace it with a concrete test that programmatically inspects
expected test outcomes (e.g., using pytest APIs) instead of a static
assert/documentation comment.
In `@mo_net/train/backends/log.py`:
- Around line 310-330: The start_run, end_run and run-update logic in
log_iteration in this backend duplicates code in SqliteBackend (e.g., methods
start_run, end_run, log_iteration and use of DbRun, _session, _current_run), so
extract the shared lifecycle into a new base class or mixin (e.g.,
RunLifecycleMixin) that implements create/commit of DbRun and updating run
fields, then have both backends inherit that mixin and remove the duplicated
implementations; ensure the mixin exposes the same method signatures (start_run,
end_run, update_run_from_iteration) and uses the existing _session and
_current_run names so callers in LogBackend and SqliteBackend need minimal
changes.
- Around line 348-349: The in-memory backend's log_training_parameters method is
a no-op and discards the training_parameters; implement it to persist the string
so tests can assert on it (e.g., attach the value to the corresponding DbRun
instance or a local dict keyed by run id). Update the log_training_parameters
function to accept the run context (or look up the current run) and store
training_parameters on the DbRun model (or an in-memory attribute like
self._runs[run_id].training_parameters or
self._training_parameters_by_run[run_id]) so subsequent reads (e.g., get_run or
tests) return the stored parameters.
- Around line 343-346: In teardown() ensure any pending transaction is persisted
before closing the SQLAlchemy session: if self._session is set, call
self._session.commit() (catch/handle exceptions if needed) and then close and
null out self._session; reference teardown(), self._session, and log_iteration
(and the SqliteBackend batching behavior) so the change protects against future
batching or interrupted commits and makes intent explicit.
- Around line 293-383: The InMemorySqliteBackend isn't returned by
parse_connection_string because its connection_string ("sqlite://:memory:")
doesn't match the in-memory pattern parse_connection_string checks (and the
engine in __init__ uses "sqlite:///:memory:"), so either update
parse_connection_string to recognise the in-memory SQLite URIs (accept
"sqlite:///:memory:" and "sqlite://:memory:") and return InMemorySqliteBackend,
or make the backend's connection_string match the pattern
parse_connection_string expects; update the InMemorySqliteBackend docstring to
clarify if this backend is intentionally only for direct construction/testing.
|
|
||
| # Initialize node vectors (one per internal node) | ||
| if key is None: | ||
| key = jax.random.PRNGKey(0) |
There was a problem hiding this comment.
Silent fallback to PRNGKey(0) when key is None reduces reproducibility.
When key is not provided, the layer silently uses PRNGKey(0) for initialization. While this is a fallback for deserialization (where parameters are overwritten), it means a caller who forgets to pass key during normal construction gets non-random "random" initialization without any warning.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mo_net/model/layer/output.py` around lines 410 - 413, The code silently falls
back to jax.random.PRNGKey(0) when key is None during node-vector
initialization, harming reproducibility; instead, remove the silent fallback and
raise a clear error (ValueError) when key is None, telling the caller to supply
an explicit PRNGKey (or an explicit seed/flag when deserializing); update the
block that currently checks "if key is None: key = jax.random.PRNGKey(0)"
(referencing the variable key and the jax.random.PRNGKey(0) usage in the
"Initialize node vectors (one per internal node)" section) to validate key and
raise with a descriptive message so callers must opt in to any default.
| @jit | ||
| def rnn_step(h_prev, x_t): | ||
| """Single RNN step: h_t = activation(x_t @ W_ih + h_{t-1} @ W_hh + b)""" | ||
| h_t = self._activation_fn( | ||
| x_t @ self._parameters.weights_ih | ||
| + h_prev @ self._parameters.weights_hh | ||
| + self._parameters.biases | ||
| ) | ||
| return h_t, h_t |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, locate and examine the actual file
find . -name "recurrent.py" -type f | head -5Repository: modiase/mo-net
Length of output: 91
🏁 Script executed:
#!/bin/bash
# Read the full context around the mentioned lines
if [ -f "mo_net/model/layer/recurrent.py" ]; then
wc -l mo_net/model/layer/recurrent.py
fiRepository: modiase/mo-net
Length of output: 93
🏁 Script executed:
#!/bin/bash
# Get full context of the _forward_prop method and surrounding code
if [ -f "mo_net/model/layer/recurrent.py" ]; then
# Read lines around 369-377 with more context
head -400 mo_net/model/layer/recurrent.py | tail -100 | cat -n
fiRepository: modiase/mo-net
Length of output: 4940
🏁 Script executed:
#!/bin/bash
# Search for other `@jit` usages inside methods to see if this is a pattern
rg -n '@jit' --type=py mo_net/ | head -20Repository: modiase/mo-net
Length of output: 463
🏁 Script executed:
#!/bin/bash
# Check the exact context around the _forward_prop and how parameters are used
sed -n '48,100p' mo_net/model/layer/recurrent.py | cat -nRepository: modiase/mo-net
Length of output: 2375
🏁 Script executed:
#!/bin/bash
# Check if there are tests or documentation about JAX scan and JIT
find . -type f \( -name "*.py" -o -name "*.md" \) -exec grep -l "jax.lax.scan" {} \;Repository: modiase/mo-net
Length of output: 91
🏁 Script executed:
#!/bin/bash
# Check line 219 context (first `@jit` in recurrent.py)
sed -n '215,225p' mo_net/model/layer/recurrent.py | cat -nRepository: modiase/mo-net
Length of output: 507
🏁 Script executed:
#!/bin/bash
# Check line 445 context (second `@jit` inside method in recurrent.py)
sed -n '440,455p' mo_net/model/layer/recurrent.py | cat -nRepository: modiase/mo-net
Length of output: 873
🏁 Script executed:
#!/bin/bash
# Check the linear.py `@jit` patterns at lines 317, 338
sed -n '310,325p' mo_net/model/layer/linear.py | cat -nRepository: modiase/mo-net
Length of output: 644
🏁 Script executed:
#!/bin/bash
# Check line 338 context in linear.py
sed -n '335,350p' mo_net/model/layer/linear.py | cat -nRepository: modiase/mo-net
Length of output: 783
🏁 Script executed:
#!/bin/bash
# See if there's any documentation or best practices in the repo
find . -type f \( -name "*.md" -o -name "*.txt" \) -exec grep -l "jit\|JAX" {} \;Repository: modiase/mo-net
Length of output: 69
🌐 Web query:
JAX jax.lax.scan JIT compilation performance optimization
💡 Result:
Key ways to improve JIT compilation performance (compile time + avoiding recompiles) when using jax.lax.scan:
- Use
lax.scan(orlax.fori_loop) instead of Pythonforinsidejitfor non-trivial/long loops: Python loops inside ajitget unrolled, which can make the compiled program huge and compilation time grow quickly;scanlowers to a single XLAWhileOp, compiling the body once. [1][2] - Tune
unrollto balance compile-time vs runtime: higherunrollincreases unrolling (often faster runtime, larger compile); lowerunrollkeeps the loop rolled (often smaller/faster compile).unroll=0(orTrue) fully unrolls. [1][2] - Keep shapes/dtypes static across iterations and across calls:
scanrequires the carry to have a fixed shape/dtype across iterations. [1]- JIT caches by function + argument shapes/dtypes; if input shapes vary a lot (e.g., variable sequence lengths), you’ll trigger frequent recompilations—pad/bucket to reduce shape variety. [3]
- If iteration count is dynamic, don’t force it into
scan:scanis designed for a static number of iterations; preferlax.fori_loop/lax.while_loopfor dynamic trip counts. [1]
Sources:
[1] JAX docs: jax.lax.scan (lowering to WhileOp, static iteration guidance, unroll, fixed carry shape/dtype)
[2] Discussion of Python-loop unrolling vs scan/fori_loop, compile-time tradeoffs, and unroll tuning
[3] Note on JIT cache misses / recompilation when input shapes vary frequently
🌐 Web query:
JAX @jit decorator retracing closure mutable parameters best practices
💡 Result:
When @jax.jit “retraces”/recompiles, it’s almost always because something in the compilation cache key changed: input shape/dtype/sharding, a static argument, or (less obviously) the Python function object you’re JITing/creating.
Best practices (esp. with closures + “mutable params”)
-
Prefer explicit arguments over closing over mutable values
- If you close over a value that changes (Python object, config, non-scalar array, etc.), you can trigger recompiles or get confusing “stale” behavior.
- JAX has specific internal handling for closed-over constants (notably non-scalar arrays encountered during tracing that don’t depend on function arguments), which is a good reason to pass such arrays as explicit arguments instead of capturing them in a closure. [1]
-
Keep jitted functions “stable”: don’t recreate closures in hot loops
- If you build a new inner function/closure repeatedly (new
id), you often defeat caching and end up recompiling because you’re effectively JITing a different function object each time. (This is a common gotcha in systems that warn against “temporary functions” whose identity changes.) [2]
- If you build a new inner function/closure repeatedly (new
-
Treat “parameters/state” as PyTrees (arrays dynamic; metadata static)
- Put arrays (weights, tensors) in the dynamic part of a PyTree so changing their values does not recompile; only changes to shape/dtype/sharding do. [3]
- Put only truly-static, small metadata (bools/ints/strings) in the static part; changing static values recompiles. [3]
-
Use
static_argnums/static_argnamesonly for small, hashable, immutable things- JAX’s own guidance: static args should be hashable and immutable; marking large/mutable objects (e.g., classes containing arrays) as static can cause surprises and cache/memory issues. [4]
- If you must make a custom object static, define meaningful
__hash__/__eq__and keep it immutable. [5]
-
Keep shapes static (or accept recompiles)
- JIT specialization is shape-based; calling with a new shape will recompile. If varying-length data is unavoidable, common workarounds are padding+masking or bucketing by shape. [6]
A robust pattern (avoid closure-mutation issues)
import jax, jax.numpy as jnp
`@jax.jit`
def step(params, x, *, cfg): # cfg can be static_argnames if truly static
# params: pytree of arrays; x: array
# cfg: small immutable metadata (or avoid and use jnp.where/lax instead)
return params, x @ params["W"] + params["b"]If you share a small snippet of your current closure/“mutable params” setup, I can point out exactly what in your code is likely entering the cache key and causing retraces.
Sources: Handling closed-over constants [1]; avoiding temporary/nested functions whose id changes [2]; closure/PyTree static-vs-dynamic recompilation behavior [3]; static arg mutability/hashability caveats and cache/memory issues [4]; JAX “Sharp Bits” note on hashing/equality for static objects [5]; shape change implies recompilation [6].
[1] (docs.jax.dev) [2] (tum-pbs.github.io) [3] (stackoverflow.com) [4] (github.com) [5] (docs.jax.dev) [6] (stackoverflow.com)
Citations:
- 1: https://docs.jax.dev/en/latest/internals/constants.html?utm_source=openai
- 2: https://tum-pbs.github.io/PhiFlow/Known_Issues.html?utm_source=openai
- 3: https://stackoverflow.com/questions/79744486/jax-recompilation-when-using-closure-for-a-function?utm_source=openai
- 4: Memory leak when using @partial(jax.jit, static_argnums=(0,)) in class jax-ml/jax#11855
- 5: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html?utm_source=openai
- 6: https://stackoverflow.com/questions/70126391/jax-avoid-just-in-time-recompilation-for-a-function-evaluated-with-a-varying-nu?utm_source=openai
Remove the @jit decorator; jax.lax.scan already provides efficient compilation.
Defining @jit-decorated functions inside _forward_prop causes unnecessary retracing on every forward pass. Since _forward_prop is called repeatedly during training (after each parameter update), a new rnn_step function object is created each time, and JAX's JIT cache—which relies on function identity—does not recognize it as the same function. This triggers recompilation even though the function logic hasn't changed. Additionally, closing over self._parameters (which gets updated during training) exacerbates this.
jax.lax.scan itself compiles the loop body once to an efficient XLA operation, so the @jit decorator on rnn_step is redundant and harmful. Remove it.
Note: This pattern appears in other methods (e.g., bptt_step at line 445, and similar patterns in linear.py) and should be addressed across the codebase.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mo_net/model/layer/recurrent.py` around lines 369 - 377, Remove the `@jit`
decorator from the nested rnn_step in _forward_prop and stop closing over
mutable self state: capture the needed tensors (weights_ih, weights_hh, biases,
activation_fn) into local variables before defining rnn_step (or pass them as
explicit args to the scan body) so the function identity is stable across
forward passes; do the same refactor for bptt_step and analogous nested
functions in linear.py to avoid repeated JIT retracing and rely on
jax.lax.scan's compilation instead.
| @jit | ||
| def bptt_step(carry, t): | ||
| """Backward pass for single timestep.""" | ||
| dh_next, dW_ih_acc, dW_hh_acc, dB_acc = carry | ||
|
|
||
| # Get current gradients and activations | ||
| dh_t = ( | ||
| dh_all[:, t, :] + dh_next | ||
| ) # Gradient from output + gradient from next timestep | ||
| h_t = hidden_states[:, t + 1, :] # Current hidden state | ||
| h_prev = hidden_states[:, t, :] # Previous hidden state | ||
| x_t = input_activations[:, t, :] # Current input | ||
|
|
||
| # Gradient through activation function | ||
| dh_raw = dh_t * self._activation_fn.deriv( | ||
| x_t @ self._parameters.weights_ih | ||
| + h_prev @ self._parameters.weights_hh | ||
| + self._parameters.biases | ||
| ) | ||
|
|
||
| # Gradients w.r.t. parameters | ||
| dW_ih_t = x_t.T @ dh_raw | ||
| dW_hh_t = h_prev.T @ dh_raw | ||
| dB_t = jnp.sum(dh_raw, axis=0) | ||
|
|
||
| # Gradients w.r.t. inputs | ||
| dx_t = dh_raw @ self._parameters.weights_ih.T | ||
| dh_prev = dh_raw @ self._parameters.weights_hh.T | ||
|
|
||
| return ( | ||
| dh_prev, | ||
| dW_ih_acc + dW_ih_t, | ||
| dW_hh_acc + dW_hh_t, | ||
| dB_acc + dB_t, | ||
| ), dx_t |
There was a problem hiding this comment.
Unused variable h_t inside bptt_step — dead code in a JIT-traced function.
Line 454 assigns h_t = hidden_states[:, t + 1, :] (the current hidden state), but it is never read. This wastes a dynamic slice operation inside the jax.lax.scan loop.
Proposed fix
- h_t = hidden_states[:, t + 1, :] # Current hidden state
h_prev = hidden_states[:, t, :] # Previous hidden state📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @jit | |
| def bptt_step(carry, t): | |
| """Backward pass for single timestep.""" | |
| dh_next, dW_ih_acc, dW_hh_acc, dB_acc = carry | |
| # Get current gradients and activations | |
| dh_t = ( | |
| dh_all[:, t, :] + dh_next | |
| ) # Gradient from output + gradient from next timestep | |
| h_t = hidden_states[:, t + 1, :] # Current hidden state | |
| h_prev = hidden_states[:, t, :] # Previous hidden state | |
| x_t = input_activations[:, t, :] # Current input | |
| # Gradient through activation function | |
| dh_raw = dh_t * self._activation_fn.deriv( | |
| x_t @ self._parameters.weights_ih | |
| + h_prev @ self._parameters.weights_hh | |
| + self._parameters.biases | |
| ) | |
| # Gradients w.r.t. parameters | |
| dW_ih_t = x_t.T @ dh_raw | |
| dW_hh_t = h_prev.T @ dh_raw | |
| dB_t = jnp.sum(dh_raw, axis=0) | |
| # Gradients w.r.t. inputs | |
| dx_t = dh_raw @ self._parameters.weights_ih.T | |
| dh_prev = dh_raw @ self._parameters.weights_hh.T | |
| return ( | |
| dh_prev, | |
| dW_ih_acc + dW_ih_t, | |
| dW_hh_acc + dW_hh_t, | |
| dB_acc + dB_t, | |
| ), dx_t | |
| `@jit` | |
| def bptt_step(carry, t): | |
| """Backward pass for single timestep.""" | |
| dh_next, dW_ih_acc, dW_hh_acc, dB_acc = carry | |
| # Get current gradients and activations | |
| dh_t = ( | |
| dh_all[:, t, :] + dh_next | |
| ) # Gradient from output + gradient from next timestep | |
| h_prev = hidden_states[:, t, :] # Previous hidden state | |
| x_t = input_activations[:, t, :] # Current input | |
| # Gradient through activation function | |
| dh_raw = dh_t * self._activation_fn.deriv( | |
| x_t @ self._parameters.weights_ih | |
| h_prev @ self._parameters.weights_hh | |
| self._parameters.biases | |
| ) | |
| # Gradients w.r.t. parameters | |
| dW_ih_t = x_t.T @ dh_raw | |
| dW_hh_t = h_prev.T @ dh_raw | |
| dB_t = jnp.sum(dh_raw, axis=0) | |
| # Gradients w.r.t. inputs | |
| dx_t = dh_raw @ self._parameters.weights_ih.T | |
| dh_prev = dh_raw @ self._parameters.weights_hh.T | |
| return ( | |
| dh_prev, | |
| dW_ih_acc + dW_ih_t, | |
| dW_hh_acc + dW_hh_t, | |
| dB_acc + dB_t, | |
| ), dx_t |
🧰 Tools
🪛 Ruff (0.15.1)
[error] 454-454: Local variable h_t is assigned to but never used
Remove assignment to unused variable h_t
(F841)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mo_net/model/layer/recurrent.py` around lines 445 - 479, The bptt_step
JIT-traced function contains an unused variable h_t = hidden_states[:, t + 1,
:], causing an unnecessary dynamic slice; remove the h_t assignment (and any
mentions) from bptt_step so the code no longer performs that unused slice inside
jax.lax.scan, leaving calculations that reference hidden_states and
input_activations (e.g., dh_t, h_prev, x_t, dh_raw, dW_ih_t, dW_hh_t, dB_t,
dx_t, dh_prev) unchanged.
| @property | ||
| def connection_string(self) -> str: | ||
| return "sqlite://:memory:" |
There was a problem hiding this comment.
Connection string is inconsistent with the actual engine URL.
The engine is created with "sqlite:///:memory:" (line 297, three slashes) but this property returns "sqlite://:memory:" (two slashes). In SQLAlchemy's URL format, two slashes makes :memory: the hostname rather than the database path, so round-tripping this value through create_engine() or urlparse() would produce incorrect results.
Proposed fix
`@property`
def connection_string(self) -> str:
- return "sqlite://:memory:"
+ return "sqlite:///:memory:"📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @property | |
| def connection_string(self) -> str: | |
| return "sqlite://:memory:" | |
| `@property` | |
| def connection_string(self) -> str: | |
| return "sqlite:///:memory:" |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@mo_net/train/backends/log.py` around lines 302 - 304, The connection_string
property currently returns "sqlite://:memory:" which is inconsistent with the
engine URL created using "sqlite:///:memory:" (three slashes) and causes
incorrect URL parsing; update the connection_string property (method
connection_string) to return the exact same URL used to create the engine (three
slashes, "sqlite:///:memory:") so round-tripping via create_engine()/urlparse()
yields the correct in-memory SQLite URL.
This change is
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests