|
56 | 56 |
|
57 | 57 | ## Discharging state |
58 | 58 |
|
59 | | -# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values |
60 | | -# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr |
61 | | -# into a "pure" jaxpr that takes in and outputs values and no longer has the |
62 | | -# `Read/Write/Accum` effects. |
63 | | - |
64 | | -def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * , |
65 | | - should_discharge: bool | Sequence[bool] = True, |
66 | | - ) -> tuple[core.Jaxpr, list[Any]]: |
67 | | - """Converts a jaxpr that takes in `Ref`s into one that doesn't.""" |
| 59 | + |
| 60 | +def discharge_state( |
| 61 | + jaxpr: core.Jaxpr, |
| 62 | + consts: Sequence[Any], |
| 63 | + *, |
| 64 | + should_discharge: bool | Sequence[bool] = True, |
| 65 | +) -> tuple[core.Jaxpr, Sequence[Any]]: |
| 66 | + """Converts a stateful jaxpr into a pure one. |
| 67 | +
|
| 68 | + Discharging replaces ``Ref`` inputs with regular values, threads updates |
| 69 | + through the computation, and returns updated ``Ref``s as additional outputs. |
| 70 | +
|
| 71 | + Args: |
| 72 | + jaxpr: A stateful jaxpr with ``Ref`` inputs. |
| 73 | + consts: Constants for the jaxpr. |
| 74 | + should_discharge: Whether to discharge each ``Ref`` input. If a single bool, |
| 75 | + applies to all inputs. |
| 76 | +
|
| 77 | + Returns: |
| 78 | + A tuple of ``(new_jaxpr, new_consts)`` where ``new_jaxpr`` is a jaxpr with |
| 79 | + no ``Read``/``Write``/``Accum`` effects. Discharged ``Ref`` inputs become |
| 80 | + regular value inputs, and their updated values are appended to the outputs. |
| 81 | + """ |
68 | 82 | if isinstance(should_discharge, bool): |
69 | 83 | should_discharge = [should_discharge] * len(jaxpr.invars) |
70 | 84 | in_avals = [v.aval.inner_aval |
@@ -105,37 +119,68 @@ def read(self, v: core.Atom) -> Any: |
105 | 119 | def write(self, v: core.Var, val: Any) -> None: |
106 | 120 | self.env[v] = val |
107 | 121 |
|
| 122 | + |
108 | 123 | class DischargeRule(Protocol): |
109 | 124 |
|
110 | | - def __call__(self, in_avals: Sequence[core.AbstractValue], |
111 | | - out_avals: Sequence[core.AbstractValue], *args: Any, |
112 | | - **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: |
113 | | - ... |
| 125 | + def __call__( |
| 126 | + self, |
| 127 | + in_avals: Sequence[core.AbstractValue], |
| 128 | + out_avals: Sequence[core.AbstractValue], |
| 129 | + *args: Any, |
| 130 | + **params: Any, |
| 131 | + ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: |
| 132 | + """Discharge rule for a primitive. |
| 133 | +
|
| 134 | + See :func:`discharge_state` for an explanation of what discharge means. |
| 135 | +
|
| 136 | + Args: |
| 137 | + in_avals: Input abstract values. |
| 138 | + out_avals: Output abstract values. |
| 139 | + *args: Input values. |
| 140 | + **params: Primitive parameters. |
| 141 | +
|
| 142 | + Returns: |
| 143 | + A tuple of ``(new_invals, new_outvals)`` where: |
| 144 | +
|
| 145 | + * ``new_invals`` contains updated values for discharged ``Ref`` inputs, |
| 146 | + or ``None`` if the input is not a ``Ref`` or was not updated. |
| 147 | + * ``new_outvals`` is the primitive's output. A sequence if the primitive |
| 148 | + has multiple results, otherwise a single value. |
| 149 | + """ |
| 150 | + |
114 | 151 |
|
115 | 152 | _discharge_rules: dict[core.Primitive, DischargeRule] = {} |
116 | 153 |
|
| 154 | + |
| 155 | +def register_discharge_rule(prim: core.Primitive): |
| 156 | + def register(f: DischargeRule): |
| 157 | + _discharge_rules[prim] = f |
| 158 | + |
| 159 | + return register |
| 160 | + |
| 161 | + |
117 | 162 | class PartialDischargeRule(Protocol): |
118 | | - """A partial discharge rule. |
| 163 | + """Discharge rule that supports selective discharging of ``Ref`` inputs. |
119 | 164 |
|
120 | | - Exactly like a discharge rule only it accepts a `should_discharge` |
121 | | - argument that indicates which inputs should be discharged and the |
122 | | - return value returns a tuple of which the first element is the new |
123 | | - inputs or none but only the ones that correspond to `True` entries |
124 | | - in `should_charge`. |
| 165 | + Generalizes :class:`DischargeRule` by accepting a ``should_discharge`` |
| 166 | + argument that specifies which ``Ref`` inputs to discharge. The returned |
| 167 | + ``new_invals`` must contain a non-``None`` value if and only if the |
| 168 | + corresponding ``Ref`` was discharged. |
125 | 169 | """ |
126 | 170 |
|
127 | | - def __call__(self, should_discharge: Sequence[bool], |
| 171 | + def __call__( |
| 172 | + self, |
| 173 | + should_discharge: Sequence[bool], |
128 | 174 | in_avals: Sequence[core.AbstractValue], |
129 | | - out_avals: Sequence[core.AbstractValue], *args: Any, |
130 | | - **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: |
| 175 | + out_avals: Sequence[core.AbstractValue], |
| 176 | + *args: Any, |
| 177 | + **params: Any, |
| 178 | + ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: |
131 | 179 | ... |
132 | 180 |
|
| 181 | + |
133 | 182 | _partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} |
134 | 183 |
|
135 | | -def register_discharge_rule(prim: core.Primitive): |
136 | | - def register(f: DischargeRule): |
137 | | - _discharge_rules[prim] = f |
138 | | - return register |
139 | 184 |
|
140 | 185 | def register_partial_discharge_rule(prim: core.Primitive): |
141 | 186 | def register(f: PartialDischargeRule): |
|
0 commit comments