Skip to content

Commit c21505d

Browse files
jakeharmon8KfacJaxDev
authored andcommitted
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886640
1 parent aaf3064 commit c21505d

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ implementation of the [K-FAC] optimizer and curvature estimator.
1919

2020
KFAC-JAX is written in pure Python, but depends on C++ code via JAX.
2121

22-
First, follow [these instructions](https://github.com/google/jax#installation)
22+
First, follow [these instructions](https://github.com/jax-ml/jax#installation)
2323
to install JAX with the relevant accelerator support.
2424

2525
Then, install KFAC-JAX using pip:
@@ -219,6 +219,6 @@ and the year corresponds to the project's open-source release.
219219

220220

221221
[K-FAC]: https://arxiv.org/abs/1503.05671
222-
[JAX]: https://github.com/google/jax
222+
[JAX]: https://github.com/jax-ml/jax
223223
[Haiku]: https://github.com/google-deepmind/dm-haiku
224224
[documentation]: https://kfac-jax.readthedocs.io/

docs/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
KFAC-JAX Documentation
44
======================
55

6-
KFAC-JAX is a library built on top of `JAX <https://github.com/google/jax>`_ for
6+
KFAC-JAX is a library built on top of `JAX <https://github.com/jax-ml/jax>`_ for
77
second-order optimization of neural networks and for computing scalable
88
curvature approximations.
99
The main goal of the library is to provide researchers with an easy-to-use
@@ -16,7 +16,7 @@ Installation
1616

1717
KFAC-JAX is written in pure Python, but depends on C++ code via JAX.
1818

19-
First, follow `these instructions <https://github.com/google/jax#installation>`_
19+
First, follow `these instructions <https://github.com/jax-ml/jax#installation>`_
2020
to install JAX with the relevant accelerator support.
2121

2222
Then, install KFAC-JAX using pip::

examples/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def softmax_cross_entropy(
104104
max_logits = jnp.max(logits, keepdims=True, axis=-1)
105105

106106
# It's unclear whether this stop_gradient is a good idea.
107-
# See https://github.com/google/jax/issues/13529
107+
# See https://github.com/jax-ml/jax/issues/13529
108108
max_logits = lax.stop_gradient(max_logits)
109109

110110
logits = logits - max_logits

0 commit comments

Comments
 (0)