How is GetOutputShardings supposed to work for PJRT Implementers? #9726

@jameszianxuTT

Description

We have a custom Shardy + StableHLO pipeline that manages sharding propagation inside our compiler stack. We're having trouble communicating the correct output shardings back to the framework and cannot find an obvious interface for doing so, so we wanted to ask what the intended path looks like.

To be clear, this is the path our compiler takes:

  1. We get the SHLO in Shardy dialect from the torch-xla framework
  2. We run Shardy to solve the SHLO graph
  3. We lower it to our own custom dialect and execute from there.

We do not convert the SHLO graph back to HLO (as JAX does). After the graph is solved in step 2, we would like to tell torch-xla what the correct output shardings are.

Observed Behavior

In torch_xla, we observe that output shardings are retrieved during compilation along this path:
torch_xla::XLAGraphExecutor::Compile -> torch_xla::runtime::PjRtComputationClient::Compile -> PjRtComputation constructor -> output_shardings_ = this->executable->GetOutputShardings();

This eventually calls into the base PjRtExecutable implementation of GetOutputShardings.

Output shardings seem to be extracted from the implementer side by calling PJRT_Executable_OptimizedProgram, which retrieves the post-compile MLIR from our PJRT implementation, inside xla::PjRtCApiExecutable::GetHloModules().

The MLIR is then converted to an XLA-internal HLO module construct, and output shardings are eventually extracted from that construct inside PjRtExecutable::GetOutputShardings().
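
To make the observed flow concrete, here is a minimal, self-contained model of it. This is a sketch only: `MockHloModule` and `MockExecutable` are hypothetical stand-ins and not the real XLA types, shardings are plain strings, and returning no value when an annotation is missing is an assumption of the sketch rather than confirmed XLA behavior.

```cpp
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the XLA-internal HLO module; not the real type.
struct MockHloModule {
  // One slot per program output; nullopt means the IR carried no sharding
  // annotation for that output.
  std::vector<std::optional<std::string>> result_shardings;
};

// Hypothetical stand-in for an executable whose only channel back to the
// framework is the optimized program it returns.
class MockExecutable {
 public:
  explicit MockExecutable(std::vector<MockHloModule> optimized_program)
      : optimized_program_(std::move(optimized_program)) {}

  // Models the "fetch post-compile IR and convert it" step.
  const std::vector<MockHloModule>& GetHloModules() const {
    return optimized_program_;
  }

  // Models the base-class step: shardings are read from the converted module.
  // Assumption of this sketch: a missing module or annotation yields nullopt.
  std::optional<std::vector<std::string>> GetOutputShardings() const {
    const std::vector<MockHloModule>& modules = GetHloModules();
    if (modules.empty()) return std::nullopt;
    std::vector<std::string> shardings;
    for (const std::optional<std::string>& s :
         modules.front().result_shardings) {
      if (!s.has_value()) return std::nullopt;
      shardings.push_back(*s);
    }
    return shardings;
  }

 private:
  std::vector<MockHloModule> optimized_program_;
};
```

In this model, a plugin that lowers to its own dialect and never re-emits sharding-annotated IR has no way to make `GetOutputShardings()` return a value, which is the situation we are in.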

How should this work?

This existing path suggests that a PJRT implementer "communicates" output shardings back to the framework post-compilation by generating IR that carries output shardings in a format XLA can ingest. This seems both complex and unidiomatic, because other data returned from compilation to the framework goes through well-defined PJRT interfaces (like PJRT_Executable_OutputDimensions), with PjRtCApi overrides that call those interfaces and cast the results to XLA-internal types.
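
For comparison, this is roughly the shape of interface we were expecting to find. Everything below is hypothetical: `Hypo_OutputShardings_Args`, `Hypo_Executable_OutputShardings`, `HypoPluginExecutable`, and `ReadOutputShardings` are names invented for illustration and are not part of the PJRT C API. The sharding payloads are treated as opaque serialized bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// HYPOTHETICAL args struct for a dedicated query; not a real PJRT type.
struct Hypo_OutputShardings_Args {
  size_t struct_size;             // in: sizeof(Hypo_OutputShardings_Args)
  const void* executable;         // in: opaque plugin executable
  size_t num_outputs;             // out
  const char* const* shardings;   // out: one serialized sharding per output
  const size_t* sharding_sizes;   // out: byte length of each entry
};

// HYPOTHETICAL plugin-side executable holding the shardings solved in step 2.
class HypoPluginExecutable {
 public:
  explicit HypoPluginExecutable(std::vector<std::string> shardings)
      : shardings_(std::move(shardings)) {
    for (const std::string& s : shardings_) {
      ptrs_.push_back(s.data());
      sizes_.push_back(s.size());
    }
  }
  // Copying would leave ptrs_ pointing into the source object's strings.
  HypoPluginExecutable(const HypoPluginExecutable&) = delete;
  HypoPluginExecutable& operator=(const HypoPluginExecutable&) = delete;

  size_t num_outputs() const { return shardings_.size(); }
  const char* const* ptrs() const { return ptrs_.data(); }
  const size_t* sizes() const { return sizes_.data(); }

 private:
  std::vector<std::string> shardings_;
  std::vector<const char*> ptrs_;
  std::vector<size_t> sizes_;
};

// HYPOTHETICAL plugin entry point. Returns 0 on success, nonzero on error.
// The returned buffers stay owned by the executable.
int Hypo_Executable_OutputShardings(Hypo_OutputShardings_Args* args) {
  if (args == nullptr || args->executable == nullptr) return 1;
  if (args->struct_size < sizeof(Hypo_OutputShardings_Args)) return 1;
  const auto* exec =
      static_cast<const HypoPluginExecutable*>(args->executable);
  args->num_outputs = exec->num_outputs();
  args->shardings = exec->ptrs();
  args->sharding_sizes = exec->sizes();
  return 0;
}

// HYPOTHETICAL framework-side override: call the interface and copy the
// results into framework-owned values instead of parsing returned IR.
std::vector<std::string> ReadOutputShardings(const HypoPluginExecutable& exec) {
  Hypo_OutputShardings_Args args{};
  args.struct_size = sizeof(args);
  args.executable = &exec;
  const int err = Hypo_Executable_OutputShardings(&args);
  assert(err == 0);
  (void)err;  // keep release builds free of unused-variable warnings
  std::vector<std::string> out;
  for (size_t i = 0; i < args.num_outputs; ++i) {
    out.emplace_back(args.shardings[i], args.sharding_sizes[i]);
  }
  return out;
}
```

With something along these lines, the plugin would report the shardings it solved directly, and the framework would not need to round-trip through an HLO module to recover them.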

What is the recommended way to communicate output shardings to the framework from a lower-level compiler?
