Skip to content

Commit 5f8d5f5

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[state] Added some docs to DischargeRule and PartialDischargeRule
PiperOrigin-RevId: 824407066
1 parent 376bc38 commit 5f8d5f5

File tree

1 file changed

+71
-26
lines changed

1 file changed

+71
-26
lines changed

jax/_src/state/discharge.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,29 @@
5656

5757
## Discharging state
5858

59-
# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values
60-
# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr
61-
# into a "pure" jaxpr that takes in and outputs values and no longer has the
62-
# `Read/Write/Accum` effects.
63-
64-
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
65-
should_discharge: bool | Sequence[bool] = True,
66-
) -> tuple[core.Jaxpr, list[Any]]:
67-
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
59+
60+
def discharge_state(
61+
jaxpr: core.Jaxpr,
62+
consts: Sequence[Any],
63+
*,
64+
should_discharge: bool | Sequence[bool] = True,
65+
) -> tuple[core.Jaxpr, Sequence[Any]]:
66+
"""Converts a stateful jaxpr into a pure one.
67+
68+
Discharging replaces ``Ref`` inputs with regular values, threads updates
69+
through the computation, and returns updated ``Ref``s as additional outputs.
70+
71+
Args:
72+
jaxpr: A stateful jaxpr with ``Ref`` inputs.
73+
consts: Constants for the jaxpr.
74+
should_discharge: Whether to discharge each ``Ref`` input. If a single bool,
75+
applies to all inputs.
76+
77+
Returns:
78+
A tuple of ``(new_jaxpr, new_consts)`` where ``new_jaxpr`` is a jaxpr with
79+
no ``Read``/``Write``/``Accum`` effects. Discharged ``Ref`` inputs become
80+
regular value inputs, and their updated values are appended to the outputs.
81+
"""
6882
if isinstance(should_discharge, bool):
6983
should_discharge = [should_discharge] * len(jaxpr.invars)
7084
in_avals = [v.aval.inner_aval
@@ -105,37 +119,68 @@ def read(self, v: core.Atom) -> Any:
105119
def write(self, v: core.Var, val: Any) -> None:
106120
self.env[v] = val
107121

122+
108123
class DischargeRule(Protocol):
109124

110-
def __call__(self, in_avals: Sequence[core.AbstractValue],
111-
out_avals: Sequence[core.AbstractValue], *args: Any,
112-
**params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
113-
...
125+
def __call__(
126+
self,
127+
in_avals: Sequence[core.AbstractValue],
128+
out_avals: Sequence[core.AbstractValue],
129+
*args: Any,
130+
**params: Any,
131+
) -> tuple[Sequence[Any | None], Any | Sequence[Any]]:
132+
"""Discharge rule for a primitive.
133+
134+
See :func:`discharge_state` for an explanation of what discharge means.
135+
136+
Args:
137+
in_avals: Input abstract values.
138+
out_avals: Output abstract values.
139+
*args: Input values.
140+
**params: Primitive parameters.
141+
142+
Returns:
143+
A tuple of ``(new_invals, new_outvals)`` where:
144+
145+
* ``new_invals`` contains updated values for discharged ``Ref`` inputs,
146+
or ``None`` if the input is not a ``Ref`` or was not updated.
147+
* ``new_outvals`` is the primitive's output. A sequence if the primitive
148+
has multiple results, otherwise a single value.
149+
"""
150+
114151

115152
_discharge_rules: dict[core.Primitive, DischargeRule] = {}
116153

154+
155+
def register_discharge_rule(prim: core.Primitive):
156+
def register(f: DischargeRule):
157+
_discharge_rules[prim] = f
158+
159+
return register
160+
161+
117162
class PartialDischargeRule(Protocol):
118-
"""A partial discharge rule.
163+
"""Discharge rule that supports selective discharging of ``Ref`` inputs.
119164
120-
Exactly like a discharge rule only it accepts a `should_discharge`
121-
argument that indicates which inputs should be discharged and the
122-
return value returns a tuple of which the first element is the new
123-
inputs or none but only the ones that correspond to `True` entries
124-
in `should_charge`.
165+
Generalizes :class:`DischargeRule` by accepting a ``should_discharge``
166+
argument that specifies which ``Ref`` inputs to discharge. The returned
167+
``new_invals`` must contain a non-``None`` value if and only if the
168+
corresponding ``Ref`` was discharged.
125169
"""
126170

127-
def __call__(self, should_discharge: Sequence[bool],
171+
def __call__(
172+
self,
173+
should_discharge: Sequence[bool],
128174
in_avals: Sequence[core.AbstractValue],
129-
out_avals: Sequence[core.AbstractValue], *args: Any,
130-
**params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
175+
out_avals: Sequence[core.AbstractValue],
176+
*args: Any,
177+
**params: Any,
178+
) -> tuple[Sequence[Any | None], Any | Sequence[Any]]:
131179
...
132180

181+
133182
_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {}
134183

135-
def register_discharge_rule(prim: core.Primitive):
136-
def register(f: DischargeRule):
137-
_discharge_rules[prim] = f
138-
return register
139184

140185
def register_partial_discharge_rule(prim: core.Primitive):
141186
def register(f: PartialDischargeRule):

0 commit comments

Comments
 (0)