3 files changed: +5 −5

@@ -19,7 +19,7 @@ implementation of the [K-FAC] optimizer and curvature estimator.

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

-First, follow [these instructions](https://github.com/google/jax#installation)
+First, follow [these instructions](https://github.com/jax-ml/jax#installation)
to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip:
@@ -219,6 +219,6 @@ and the year corresponds to the project's open-source release.


[K-FAC]: https://arxiv.org/abs/1503.05671
-[JAX]: https://github.com/google/jax
+[JAX]: https://github.com/jax-ml/jax
[Haiku]: https://github.com/google-deepmind/dm-haiku
[documentation]: https://kfac-jax.readthedocs.io/

KFAC-JAX Documentation
======================

-KFAC-JAX is a library built on top of `JAX <https://github.com/google/jax>`_ for
+KFAC-JAX is a library built on top of `JAX <https://github.com/jax-ml/jax>`_ for
second-order optimization of neural networks and for computing scalable
curvature approximations.
The main goal of the library is to provide researchers with an easy-to-use
@@ -16,7 +16,7 @@ Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

-First, follow `these instructions <https://github.com/google/jax#installation>`_
+First, follow `these instructions <https://github.com/jax-ml/jax#installation>`_
to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip::

@@ -104,7 +104,7 @@ def softmax_cross_entropy(
  max_logits = jnp.max(logits, keepdims=True, axis=-1)

  # It's unclear whether this stop_gradient is a good idea.
-  # See https://github.com/google/jax/issues/13529
+  # See https://github.com/jax-ml/jax/issues/13529
  max_logits = lax.stop_gradient(max_logits)

  logits = logits - max_logits
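
The hunk above is the usual max-subtraction trick for computing a numerically stable softmax cross-entropy. As a rough, self-contained sketch of that pattern in plain JAX (the function name and body here are illustrative assumptions, not KFAC-JAX's actual implementation), it can be written as:

```python
# Illustrative sketch only -- not KFAC-JAX's actual implementation.
import jax.numpy as jnp
from jax import lax


def stable_softmax_cross_entropy(logits, labels):
  """Cross-entropy of softmax(logits) against one-hot labels, per example."""
  # Subtract the per-row max so exp() cannot overflow; the shift cancels in
  # the normalized log-probabilities, so the forward value is unchanged.
  # Whether to wrap the max in stop_gradient is the open question referenced
  # in the diff above (jax-ml/jax#13529).
  max_logits = lax.stop_gradient(jnp.max(logits, axis=-1, keepdims=True))
  shifted = logits - max_logits
  log_normalizer = jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))
  log_probs = shifted - log_normalizer
  return -jnp.sum(labels * log_probs, axis=-1)


# Example: one example with three classes.
logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([[1.0, 0.0, 0.0]])
print(stable_softmax_cross_entropy(logits, labels))  # ~= [0.417]
```

Because the subtracted maximum cancels in the normalized log-probabilities, the `stop_gradient` only changes how gradients are routed through the max, which is why the original comment hedges on whether it is a good idea.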