Refactor: Module restructuring, rename ETraceState to HiddenState, add eligibility trace tests, and fix weight computation bugs#59
Merged
chaoming0625 merged 19 commits intomainfrom Mar 18, 2026
Conversation
…__init__.py, rename esd_rtrl to pp_prop, and adjust __all__ exports for better organization
…ate import paths in __init__.py
…ts into single-line imports for improved readability
…st.py and refactor imports for improved readability
….py and __init__.py for better module organization
…ht function call in matrix multiplication
…mproved clarity and robustness
… for gradient checks
…ddenState and related changes
… formatting initialization parameters
There was a problem hiding this comment.
Pull request overview
This PR updates BrainTrace’s packaging, APIs, examples, and documentation to align with newer brainstate/jax usage, including migrating ETrace state docs to brainstate.HiddenState and removing spike-based readout support.
Changes:
- Bump Python requirement to 3.11 and update version handling via a dedicated
braintrace/_version.py. - Remove
LeakySpikeReadout(code, docs, and tests) and clean upnnexports. - Refactor compiler/VJP internals for determinism, compatibility, and expanded test coverage; update docs notebooks and Sphinx build configuration.
Reviewed changes
Copilot reviewed 66 out of 71 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| pyproject.toml | Bumps Python requirement; adjusts dynamic version source; trims optional dependency groups. |
| examples/snn_models.py | Fixes initializer call signatures and reformats layer construction for clarity. |
| examples/002-coba-ei-rsnn.py | Updates BrainPy import style. |
| examples/001-gif-snn-for-dms.py | Removes stray no-op statement in __main__. |
| docs/tutorial/etracestate-zh.ipynb | Migrates tutorial terminology from ETraceState to brainstate.HiddenState; clears outputs. |
| docs/tutorial/etracestate-en.ipynb | Same migration as zh version; updates links and text. |
| docs/tutorial/etraceop-zh.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/tutorial/etraceop-en.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/quickstart/snn_online_learning-zh.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/quickstart/snn_online_learning-en.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/quickstart/rnn_online_learning-zh.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/quickstart/concepts-zh.ipynb | Updates state concept references to brainstate.HiddenState; formatting changes. |
| docs/quickstart/concepts-en.ipynb | Updates state concept references to brainstate.HiddenState; wording updates. |
| docs/highlight_lexer.py | Adds a utility script to rewrite ipython2 lexer metadata in notebooks. |
| docs/conf.py | Adds docs directory to sys.path and runs notebook lexer fixer during Sphinx import. |
| docs/apis/nn.rst | Removes LeakySpikeReadout from generated API docs. |
| docs/apis/concepts.rst | Documents migration of ETraceState classes to brainstate.Hidden*State. |
| docs/advanced/online_algorithm_customization-zh.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/advanced/online_algorithm_customization-en.ipynb | Updates notebook lexer metadata to ipython3. |
| docs/advanced/limitations-zh.ipynb | Updates text to refer to HiddenState (brainstate) + lexer fixes. |
| docs/advanced/limitations-en.ipynb | Updates text to refer to HiddenState (brainstate) + lexer fixes. |
| docs/advanced/IR_analysis-zh.ipynb | Clears outputs, adjusts formatting, updates lexer metadata. |
| docs/advanced/IR_analysis-en.ipynb | Clears outputs, adjusts formatting, updates lexer metadata. |
| changelog.md | Removes VJP import snippet from changelog section. |
| braintrace/nn/_rnn.py | Fixes forget-bias random range; avoids stale-read in complex state update. |
| braintrace/nn/_readout_test.py | Removes spike readout tests; cleans imports; keeps rate readout coverage. |
| braintrace/nn/_readout.py | Removes LeakySpikeReadout; fixes tau shape for LeakyRateReadout. |
| braintrace/nn/_linear.py | Normalizes in_size/out_size to tuples for consistency. |
| braintrace/nn/init.py | Replaces star-exports with explicit exports; updates deprecated state forwarding. |
| braintrace/_version.py | Introduces central version module for packaging. |
| braintrace/_state_managment_test.py | Adds unit tests for state assignment/splitting utilities. |
| braintrace/_state_managment.py | Removes legacy APIs; strengthens key mismatch handling with ValueError. |
| braintrace/_grad_exponential.py | Stores gradients in LongTermState for transform tracking; fixes update mutation. |
| braintrace/_etrace_vjp/pp_prop.py | Consolidates compiler imports; clamps correction factor to avoid divide-by-zero. |
| braintrace/_etrace_vjp/misc_test.py | Adds tests for VJP misc helpers. |
| braintrace/_etrace_vjp/hybrid.py | Refactors imports; renames exported algorithm; fixes leaf traversal call. |
| braintrace/_etrace_vjp/graph_executor.py | Consolidates compiler imports; switches to get_aval helper. |
| braintrace/_etrace_vjp/d_rtrl_test.py | Expands tests for normalization/unit helpers and algorithm behavior. |
| braintrace/_etrace_vjp/d_rtrl.py | Consolidates compiler/misc imports for new compiler package structure. |
| braintrace/_etrace_vjp/base_test.py | Adds extensive tests for VJP base algorithm behavior and compilation flow. |
| braintrace/_etrace_vjp/base.py | Minor import simplification and formatting. |
| braintrace/_etrace_vjp/init.py | Replaces star-exports with explicit exports. |
| braintrace/_etrace_operators_test.py | Improves assertions; adds coverage for weight_fn path. |
| braintrace/_etrace_operators.py | Fixes double-application of weight_fn; fixes convolution placeholder usage. |
| braintrace/_etrace_model_test.py | Updates initializer calls to tuple shapes; touches concat-based weight assembly. |
| braintrace/_etrace_input_data.py | Ensures deterministic merge order for flattened data. |
| braintrace/_etrace_graph_executor.py | Switches to new compiler package import; invalidates cached mappings on recompile. |
| braintrace/_etrace_concepts.py | Adds support for ETraceOp-backed elementwise params via internal flag. |
| braintrace/_etrace_compiler/module_info.py | Uses warnings.warn; preserves insertion order when deduplicating invars. |
| braintrace/_etrace_compiler/hidden_pertubation_test.py | Updates import path to new compiler package layout. |
| braintrace/_etrace_compiler/hidden_pertubation.py | Improves determinism and robustness; avoids unresolved refs; fixes perturb init zeros. |
| braintrace/_etrace_compiler/hidden_group_test.py | Updates import paths; fixes incorrect assertions. |
| braintrace/_etrace_compiler/hidden_group.py | Updates imports; adds cycle/ordering failure detection during topo sort. |
| braintrace/_etrace_compiler/hid_param_op.py | Refactors imports; fixes tracer field docs; adds sentinel-safe replace(). |
| braintrace/_etrace_compiler/graph.py | Refactors imports; makes temp outvar selection deterministic. |
| braintrace/_etrace_compiler/base_test.py | Adds tests for compiler base utilities and JAXPR evaluation dispatch. |
| braintrace/_etrace_compiler/base.py | Fixes swapped op names for scan/while unsupported-op checks; import refactor. |
| braintrace/_etrace_compiler/init.py | Adds public re-export surface for new compiler subpackage. |
| braintrace/_etrace_algorithms_test.py | Adds unit tests for ETraceAlgorithm base behaviors and compilation workflow. |
| braintrace/_etrace_algorithms.py | Switches to new compiler import and invalidates cached splits on compile. |
| braintrace/init.py | Replaces star-exports with explicit API surface; sources version from _version. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
Comment on lines
+88
to
+95
| # 创建备份 | ||
| backup_path = file_path + '.backup' | ||
| with open(backup_path, 'w', encoding='utf-8') as f: | ||
| json.dump(data, f, indent=2, ensure_ascii=False) | ||
|
|
||
| # 保存修复后的文件 | ||
| with open(file_path, 'w', encoding='utf-8') as f: | ||
| json.dump(data, f, indent=2, ensure_ascii=False) |
| import sys | ||
|
|
||
| sys.path.insert(0, os.path.abspath('../')) | ||
| sys.path.insert(0, os.path.abspath('./')) |
|
|
||
| from highlight_lexer import fix_ipython2_lexer_in_notebooks | ||
|
|
||
| fix_ipython2_lexer_in_notebooks(os.path.abspath(os.path.dirname(os.path.abspath(__file__)))) |
| self.n_inh_in = n_in - self.n_exc_in | ||
|
|
||
| weight = jnp.concat([ff_init([self.n_exc_in, n_rec]), rec_init([self.n_exc_rec, n_rec])], axis=0) | ||
| weight = jnp.concat([ff_init((self.n_exc_in, n_rec)), rec_init((self.n_exc_rec, n_rec))], axis=0) |
Comment on lines
+1084
to
+1085
| "pygments_lexer": "ipython3", | ||
| "version": "2.7.6" |
Comment on lines
+46
to
51
| from brainstate._compatible_import import get_aval | ||
| from jax.extend import linear_util as lu | ||
| from jax.interpreters import partial_eval as pe | ||
| from jax.tree_util import register_pytree_node_class | ||
|
|
||
| from braintrace._compatible_imports import Var |
Comment on lines
+362
to
372
| self._is_etrace_op = False | ||
| if isinstance(op, ETraceOp): | ||
| assert isinstance(op, ElemWiseOp), ( | ||
| f'op should be ElemWiseOp. ' | ||
| f'But we got {type(op)}.' | ||
| ) | ||
| op = op.xw_to_y | ||
| self._is_etrace_op = True | ||
| self.op = op | ||
| self.value = weight | ||
| self.name = name |
Comment on lines
+384
to
+386
| if self._is_etrace_op: | ||
| return self.op(None, self.value) | ||
| return self.op(self.value) |
…fects' to 'effs' in JaxprEqn initialization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
🔧 1. Module & Import Restructuring
__init__.pyfiles throughout the project._version.pyfile.esd_rtrlmodule topp_propfor better naming clarity.__all__exports across multiple modules.🗑️ 2. Removal of
LeakySpikeReadoutLeakySpikeReadoutclass from_readout.pyand its corresponding tests, streamlining the readout module.🔄 3. Class Renaming:
ETraceState→HiddenStateETraceStateclass toHiddenStatefor better semantic clarity, with documentation and code updated accordingly.🐛 4. Bug Fixes
_etrace_operators.pymatrix multiplication._version.py.✅ 5. New & Refactored Tests
_state_management_test.py).ETraceAlgorithmandEligibilityTracefunctionality.LeakySpikeReadouttests.🛡️ 6. State Management & Error Handling