🚀 Feature
Introduce an optional priority parameter to mark_sharding that allows users to explicitly guide sharding propagation when multiple operands are sharded along the same dimension.
Motivation
When multiple tensors are sharded along the same dimension, sharding propagation is currently decided based on Shardy’s internal priority rules.
While this works in many cases, it does not allow users to explicitly express intent when those defaults lead to suboptimal or unexpected results.
This becomes especially problematic in cases like embedding weights vs. runtime indices, where one operand’s sharding should clearly dominate propagation decisions, but cannot be enforced today.
Providing a way to surface user intent improves correctness, predictability, and debuggability of sharding behavior.
Pitch
Extend mark_sharding with an optional priority argument:
xs.mark_sharding(tensor, mesh, partition_spec, priority=tuple)
where priority is a tuple providing one priority value per dimension of partition_spec.
The priority is stored in torch_xla, forwarded as a frontend attribute, and lowered into the sdy.sharding attribute as per-dimension priority annotations.
Shardy's existing user-priority propagation pass then consumes this information and drives propagation according to the specified priorities, without introducing any new propagation logic.
This keeps the design minimal while leveraging existing Shardy infrastructure.
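To make the intended semantics concrete, here is a minimal pure-Python sketch (not torch_xla code) of the per-dimension conflict resolution this proposal enables. The function names are hypothetical, and the convention that a lower number means a stronger priority (mirroring Shardy's pN notation, where p0 is highest) is an assumption for illustration only:

```python
# Sketch: when two operands propose different shardings for the same
# tensor dimension, the proposal with the lower priority number wins.
# This mimics the effect of e.g.
#   xs.mark_sharding(weight, mesh, ("model", None), priority=(0, 1))
#   xs.mark_sharding(indices, mesh, ("data", None), priority=(1, 1))
# where the embedding weight's sharding should dominate propagation.

def resolve_dim(proposals):
    """proposals: list of (mesh_axis_or_None, priority) for one dimension."""
    sharded = [p for p in proposals if p[0] is not None]
    if not sharded:
        return None
    # Lower priority number = stronger user intent.
    return min(sharded, key=lambda p: p[1])[0]

def resolve(*specs):
    """Each spec: list of (mesh_axis, priority) pairs, one per dimension."""
    return [resolve_dim(list(dims)) for dims in zip(*specs)]

weight_spec  = [("model", 0), (None, 1)]
indices_spec = [("data", 1), (None, 1)]
print(resolve(weight_spec, indices_spec))  # ['model', None]
```

The sketch only models the tie-breaking rule; in the actual proposal this decision is made by Shardy's propagation pass, not by user code.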
Alternatives
Additional context
- The proposal is backward compatible: if priority is not specified, existing behavior remains unchanged.
- Priority is encoded as a simple per-dimension attribute, making it easy to inspect and debug in IR.
- This approach has been validated internally and integrates cleanly with Shardy’s current propagation passes.
- An example implementation has been prepared for illustration and will be shared as a PR.
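For a rough sense of what the lowered annotation could look like in IR, the following is a hypothetical fragment using Shardy's pN dimension-priority notation (the mesh name and exact attribute placement are illustrative assumptions, not output from the example implementation):

```mlir
// Hypothetical lowered form: dim 0 sharded on mesh axis "model" with
// the strongest priority p0; dim 1 left open with a weaker priority p1.
{sdy.sharding = #sdy.sharding<@mesh, [{"model"}p0, {?}p1]>}
```

Because the priority appears directly on each dimension sharding, it is straightforward to inspect and debug in the IR, as noted above.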