Refactor advanced subtensor #1756
base: main

Conversation
This is looking pretty good.

Mostly I want to believe there is still room to simplify things / reuse code.

This is also a good opportunity to simplify the idx_list. There's no reason to use ScalarTypes in the dummy slices, and it's complicating our equality and hashing. What about using simple integers to indicate the role of each index variable?

    old_idx_list = (ps.int64, slice(ps.int64, None, None), ps.int64, slice(3, ps.int64, 4))
    new_idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))

Having the indices could probably come in handy anyway. With this we shouldn't need a custom hash / eq, we can just use the default one from __props__.
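For illustration, here is a minimal standalone sketch of why the integer encoding makes the default machinery sufficient. FakeOp and the tuple normalization are hypothetical, not PyTensor code; note that plain slice objects only became hashable in Python 3.12, so a normalization step may still be needed:

```python
# Hypothetical sketch: once idx_list holds only ints, None, and slices of
# ints/None, equality and hashing can be derived mechanically from it,
# which is exactly what a __props__-based default would do.
class FakeOp:
    __props__ = ("idx_list",)

    def __init__(self, idx_list):
        # Normalize slices to plain tuples so idx_list stays hashable
        # even on Python < 3.12, where slice objects are unhashable.
        self.idx_list = tuple(
            (e.start, e.stop, e.step) if isinstance(e, slice) else e
            for e in idx_list
        )

    def __eq__(self, other):
        return type(self) is type(other) and self.idx_list == other.idx_list

    def __hash__(self):
        return hash((type(self), self.idx_list))


a = FakeOp((0, slice(1, None, None), 2, slice(3, None, 4)))
b = FakeOp((0, slice(1, None, None), 2, slice(3, None, 4)))
assert a == b and hash(a) == hash(b)
```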
    return axis


    def reconstruct_indices(idx_list, tensor_inputs):
What does this helper do? It sounds like it returns dummy slices (not with actual arguments) or tensor variables. Shouldn't the slice variables be placed inside the slices? If so, there's already a helper that does that IIRC.
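For reference, the reconstruction being asked about might look something like this hypothetical sketch (if the existing helper is indices_from_subtensor, it should already do essentially this):

```python
from pytensor.graph.type import Type


def reconstruct_indices_sketch(idx_list, tensor_inputs):
    # Hypothetical: splice the tensor inputs back into the dummy idx_list,
    # replacing each Type placeholder (including those sitting inside
    # slice fields) with the next tensor input, in order.
    inputs = iter(tensor_inputs)

    def fill(entry):
        return next(inputs) if isinstance(entry, Type) else entry

    return tuple(
        slice(fill(e.start), fill(e.stop), fill(e.step))
        if isinstance(e, slice)
        else fill(e)
        for e in idx_list
    )
```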
    z_broad[k]
    and not same_shape(xi, y, dim_x=k, dim_y=k)
    and shape_of[y][k] != 1
    and shape_of[xi][k] == 1
If this fixes a bug, it will need a specific regression test, and it should go in a separate commit.
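A regression test could follow the usual shape, something like this skeleton (the indexing pattern here is a stand-in, not the graph that actually triggered the bug):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt


def test_inc_subtensor_broadcast_regression():
    # Stand-in graph: increment broadcasted rows and compare with NumPy.
    x = pt.matrix("x")
    y = pt.vector("y")
    idx = pt.lvector("idx")
    out = pt.inc_subtensor(x[idx], y)  # y broadcasts across indexed rows

    fn = pytensor.function([x, y, idx], out)

    x_val = np.zeros((4, 3), dtype=x.dtype)
    y_val = np.arange(3).astype(x.dtype)
    idx_val = np.array([0, 2, 2])  # duplicate index: increments accumulate

    expected = x_val.copy()
    np.add.at(expected, idx_val, y_val)
    np.testing.assert_allclose(fn(x_val, y_val, idx_val), expected)
```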
    else:
        x, y, *idxs = node.inputs
        x, y = node.inputs[0], node.inputs[1]
        tensor_inputs = node.inputs[2:]
I don't like the name tensor_inputs; x and y are also tensor inputs. Use index_variables?
This applies elsewhere
    if isinstance(node.op, AdvancedSubtensor):
        new_out = node.op(raveled_x, *new_idxs)
        # Create new AdvancedSubtensor with updated idx_list
        new_idx_list = list(node.op.idx_list)
Can we use some helper to do this? Isn't there already something for when you do x.__getitem__?
    # must already be raveled in the original graph, so we don't need to do anything to it
    new_out = node.op(raveled_x, y, *new_idxs)
    # But we must reshape the output to match the original shape
    new_out = AdvancedIncSubtensor(
You should use type(op) so that subclasses are respected. It may also make sense to add a method to these indexing Ops, like op.with_new_indices(), that clones itself with a new idx_list. Maybe that method could be the one that handles creating the new idx_list, instead of it having to happen here in the rewrite.
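Sketched out, assuming these Ops expose their state through __props__ (the method name comes from the suggestion above, and _props_dict() is assumed to be available as on other PyTensor Ops):

```python
def with_new_indices(self, idx_list):
    # Hypothetical method, to live on the shared indexing-Op base class:
    # clone this Op with a replaced idx_list, preserving the concrete
    # subclass and any other props.
    props = self._props_dict()
    props["idx_list"] = tuple(idx_list)
    return type(self)(**props)
```

The rewrite above would then reduce to new_op = node.op.with_new_indices(new_idx_list) followed by new_op(raveled_x, y, *new_idxs).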
    def __init__(self, idx_list):
        """
        Initialize AdvancedSubtensor with index list.

        Parameters
        ----------
        idx_list : tuple
            List of indices where slices are stored as-is,
            and numerical indices are replaced by their types.
        """
        self.idx_list = tuple(
            index_vars_to_types(idx, allow_advanced=True) for idx in idx_list
        )
        # Store expected number of tensor inputs for validation
        self.expected_inputs_len = len(
            get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))
        )

    def __hash__(self):
        msg = []
        for entry in self.idx_list:
            if isinstance(entry, slice):
                msg += [(entry.start, entry.stop, entry.step)]
            else:
                msg += [entry]

        idx_list = tuple(msg)
        return hash((type(self), idx_list))
This already exists in Subtensor? If so, create a BaseSubtensor class that handles idx_list and hash/equality based on it, and make all Subtensor operations inherit from it.
Note this advice may make no sense if we simplify the idx_list so that it no longer needs a custom hash / eq.
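For concreteness, the shared base could look roughly like this; a sketch only, and moot if the idx_list is simplified as proposed above so that the default __props__ hash/eq suffices:

```python
from pytensor.graph.op import Op


class BaseSubtensor(Op):
    # Hypothetical shared base: owns idx_list and derives equality and
    # hashing from it, normalizing slices to tuples so the key is
    # hashable on all supported Python versions.
    __props__ = ("idx_list",)

    def __init__(self, idx_list):
        self.idx_list = tuple(idx_list)

    def _hashable_idx_list(self):
        return tuple(
            (e.start, e.stop, e.step) if isinstance(e, slice) else e
            for e in self.idx_list
        )

    def __eq__(self, other):
        return (
            type(self) is type(other)
            and self._hashable_idx_list() == other._hashable_idx_list()
        )

    def __hash__(self):
        return hash((type(self), self._hashable_idx_list()))
```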
    )
    else:
        return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
    # With the new interface, all inputs are tensors, so Blockwise can handle them
The comment should not mention a specific time period; the previous status is not relevant here.
Also, we still want to avoid Blockwise eagerly when we can.
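One shape the dispatch could take so that Blockwise stays a last resort; a hedged sketch only, where the decorator path, indices_from_subtensor, and the batch-dim bookkeeping are assumptions, not this PR's code:

```python
from pytensor.graph.replace import _vectorize_node
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.subtensor import AdvancedSubtensor, indices_from_subtensor


@_vectorize_node.register(AdvancedSubtensor)
def _vectorize_advanced_subtensor(op, node, batch_x, *batch_index_variables):
    x, *index_variables = node.inputs
    if all(
        batch_idx.type.ndim == idx.type.ndim
        for batch_idx, idx in zip(batch_index_variables, index_variables)
    ):
        # Only x gained batch dimensions: splice the index variables back
        # into the idx_list and index directly, prepending full slices for
        # the batch dimensions, instead of building a Blockwise.
        batch_ndim = batch_x.type.ndim - x.type.ndim
        indices = indices_from_subtensor(batch_index_variables, op.idx_list)
        return batch_x[(slice(None),) * batch_ndim + tuple(indices)].owner
    return vectorize_node_fallback(op, node, batch_x, *batch_index_variables)
```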
    def astype(self, dtype):
        return pt.basic.cast(self, dtype)

    def _getitem_with_newaxis(self, args):
Can we move all this logic to subtensor and just import/use it from here?
I mean all the subtensor-specific logic in __getitem__ as well.
    # Check if we can return the view directly if all new_args are full slices
    # We can't do arg == slice(None, None, None) as in
    # Python 2.7, this calls __lt__ if we have a slice
We're long past Python 2.7, so we can just check arg == slice(None).
    pattern = []
    new_args = []
    for arg in args:
        if arg is np.newaxis or arg is NoneConst:
Shouldn't have to worry about NoneConst anymore, just check for newaxis. Also, because this may be used many times, let's just do if arg is None instead of the cute newaxis from np.
            pattern.append("x")
            new_args.append(slice(None))
        else:
            # Check for boolean index which consumes multiple dimensions
This is probably right, but why does it matter that boolean indexing consumes multiple dimensions? Aren't we doing expand_dims where there was None -> replace newaxis by a None slice -> index again?
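That is, roughly this flow (a standalone sketch of the approach being described, with a note on where the boolean case might bite):

```python
import pytensor.tensor as pt


def getitem_with_newaxis_sketch(x, args):
    # Hypothetical sketch of: expand_dims where there was None ->
    # replace newaxis by a full slice -> index again.
    newaxis_positions = [i for i, a in enumerate(args) if a is None]
    if not newaxis_positions:
        return x[tuple(args)]
    # Caveat: this position arithmetic assumes each arg consumes exactly
    # one axis; a boolean mask with ndim > 1 consumes several, which may
    # be why the boolean check above is needed.
    expanded = pt.expand_dims(x, newaxis_positions)
    new_args = tuple(slice(None) if a is None else a for a in args)
    return expanded[new_args]
```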
    with pytest.raises(TypeError):
        index_vars_to_types(1)
    # Integers are now allowed
"now" is time-specific; remove the comment.
    (11, 7, 5, 3, 5),
    (2,),
    True,
    marks=pytest.mark.xfail(raises=NotImplementedError),
There are some comments saying "not supported" in between the parametrizations that should be removed.
Just wanted to repeat, this is looking great. Thanks so far @jaanerik. I'm being picky because indexing is a pretty fundamental operation, so I want to make sure we get it right this time.
Description
Allows vectorizing AdvancedSetSubtensor.
Gemini picks up where Copilot left off.
Related Issue
Checklist
Type of change