
Port output distribution classes to DynamicEmb #297

Open

z52527 wants to merge 3 commits into NVIDIA:main from z52527:fea-output-dist

Conversation

@z52527 (Collaborator) commented Feb 6, 2026

Description

#296
Port TorchRec's output distribution classes to the DynamicEmb library, enabling future performance optimizations.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@z52527 (Collaborator, Author) commented Feb 6, 2026

Track CI here.

@z52527 z52527 requested a review from shijieliu February 6, 2026 10:23
@greptile-apps (bot) commented Feb 6, 2026

Greptile Overview

Greptile Summary

  • Adds DynamicEmb equivalents of TorchRec RW output distribution modules for sequence (AlltoAll) and pooled (reduce-scatter) embeddings in output_dist.py.
  • Hooks these output distribution modules into DynamicEmb’s RW sharding implementations in planner/rw_sharding.py via new create_output_dist() overrides.
  • Enables future performance work by keeping distribution logic local to DynamicEmb while still leveraging TorchRec comm primitives and qcomm codec registry integration.

Confidence Score: 3/5

  • This PR is mergeable after fixing two concrete runtime-crash paths in the new output distribution wiring.
  • The changes are localized and follow TorchRec patterns, but (1) the sharding code can pass a None process group into dist constructors (leading to immediate crashes), and (2) pooled output dist can be initialized in variable-batch mode and later invoked without required ctx args, also crashing.
  • Files needing the most attention: corelib/dynamicemb/dynamicemb/output_dist.py, corelib/dynamicemb/dynamicemb/planner/rw_sharding.py

Important Files Changed

  • corelib/dynamicemb/dynamicemb/output_dist.py: Adds RW output distribution modules for sequence (AlltoAll) and pooled (reduce-scatter) embeddings; there is a runtime hazard if the first call initializes the variable-batch reduce-scatter and later calls omit sharding_ctx.
  • corelib/dynamicemb/dynamicemb/planner/rw_sharding.py: Wires the new output distribution classes into the RW dynamic embedding shardings; currently passes an Optional process group into the dist constructors (can crash when pg is None) and leaves TODO/commented uncertainty.

Sequence Diagram

sequenceDiagram
  participant Planner as rw_sharding.py
  participant OutDist as output_dist.py
  participant A2A as SequenceEmbeddingsAllToAll
  participant RS as PooledEmbeddingsReduceScatter
  participant VRS as VariableBatchPooledEmbeddingsReduceScatter

  Planner->>OutDist: create_output_dist(pg, num_features/embedding_dims, codecs)
  Note over Planner,OutDist: pg is currently Optional in sharding classes

  alt Sequence embeddings
    Planner->>OutDist: RwSequenceEmbeddingDist.forward(local_embs, SequenceShardingContext)
    OutDist->>A2A: alltoall(lengths/input_splits/output_splits/...)
    A2A-->>OutDist: Awaitable[Tensor]
  else Pooled embeddings
    Planner->>OutDist: RwPooledEmbeddingDist.forward(local_embs, EmbeddingShardingContext?)
    alt first call initializes dist module
      OutDist->>RS: init reduce-scatter
      OutDist->>VRS: init variable-batch reduce-scatter (if ctx.variable_batch_per_feature)
    end
    alt ctx is None
      OutDist->>RS: reduce_scatter(local_embs)
    else ctx.variable_batch_per_feature
      OutDist->>VRS: reduce_scatter(local_embs, batch_size_per_rank_per_feature, embedding_dims)
    else ctx provided
      OutDist->>RS: reduce_scatter(local_embs, input_splits=batch_size_per_rank)
    end
  end
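For context on the sequence path in the diagram, below is a rough sketch of what such a wrapper around TorchRec's SequenceEmbeddingsAllToAll typically looks like. This is an illustrative reconstruction based on the diagram and the review snippets, not the PR's actual output_dist.py; the class name is suffixed "Sketch" deliberately, codec wiring is omitted, and the exact constructor and forward arguments are assumptions to verify against the torchrec version in use.

# Illustrative sketch (not the PR's code) of the sequence-embedding output dist
# shown in the diagram above. Verify exact torchrec signatures before reuse.
from typing import Dict, Optional

import torch
import torch.distributed as dist
from torchrec.distributed.dist_data import SequenceEmbeddingsAllToAll
from torchrec.distributed.embedding_sharding import BaseEmbeddingDist
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import Awaitable, QuantizedCommCodecs


class RwSequenceEmbeddingDistSketch(
    BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
):
    def __init__(
        self,
        pg: dist.ProcessGroup,
        num_features: int,
        device: Optional[torch.device] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()
        # In row-wise sharding every rank serves the same feature set, so each
        # destination rank expects num_features features back.
        # (qcomm_codecs_registry handling is omitted in this sketch.)
        self._dist = SequenceEmbeddingsAllToAll(pg, [num_features] * pg.size(), device)

    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[SequenceShardingContext] = None,
    ) -> Awaitable[torch.Tensor]:
        assert sharding_ctx is not None, "sequence output dist needs the sharding ctx"
        # Send looked-up rows back to the ranks whose ids were bucketized to this rank.
        return self._dist(
            local_embs,
            lengths=sharding_ctx.lengths_after_input_dist,
            input_splits=sharding_ctx.input_splits,
            output_splits=sharding_ctx.output_splits,
            unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor,
        )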

@greptile-apps (bot) left a comment

2 files reviewed, 2 comments

Comment on lines +155 to +166
    def create_output_dist(
        self,
        device: Optional[torch.device] = None,
    ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
        return RwSequenceEmbeddingDist(
            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
            # `Optional[ProcessGroup]`.
            self._pg,
            self._get_num_features(),
            device if device is not None else self._device,
            qcomm_codecs_registry=self.qcomm_codecs_registry,
        )

Passes Optional ProcessGroup

create_output_dist() passes self._pg through to RwSequenceEmbeddingDist, but self._pg is typed as Optional[ProcessGroup] (and a pyre-fixme has been added to silence it). If self._pg is actually None at runtime (e.g., in non-distributed / single-rank setups), RwSequenceEmbeddingDist.__init__ will call pg.size() and crash. This needs either a real guard or a guarantee that _pg is non-None before the dist module is constructed (the pooled sharding path has the same pattern).

Also appears in: corelib/dynamicemb/dynamicemb/planner/rw_sharding.py:264-274.
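One straightforward way to address this, sketched below with the names from the snippet above (it relies on that file's imports and class context rather than being standalone), is to fail fast with an explicit error when the process group is missing instead of letting the constructor crash inside pg.size().

    # Hypothetical guard (not the PR's code): validate the process group before
    # constructing the dist module, so single-rank setups fail with a clear error.
    def create_output_dist(
        self,
        device: Optional[torch.device] = None,
    ) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
        if self._pg is None:
            raise ValueError(
                "RW sequence output dist requires a ProcessGroup, but self._pg is "
                "None (non-distributed / single-rank setup?)"
            )
        return RwSequenceEmbeddingDist(
            self._pg,
            self._get_num_features(),
            device if device is not None else self._device,
            qcomm_codecs_registry=self.qcomm_codecs_registry,
        )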

Comment on lines +130 to +146
"""
if self._dist is None:
self._create_output_dist_module(sharding_ctx)

if sharding_ctx is None:
return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
elif sharding_ctx.variable_batch_per_feature:
return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
local_embs,
batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
embedding_dims=self._embedding_dims,
)
else:
return cast(PooledEmbeddingsReduceScatter, self._dist)(
local_embs,
input_splits=sharding_ctx.batch_size_per_rank,
)

Missing context for init

When sharding_ctx is None, forward() invokes the cached _dist as a plain PooledEmbeddingsReduceScatter (lines 134-135), but _dist may have been created as a variable-batch module based on the first call's sharding_ctx (lines 131-133 / 151-155). If the first invocation had variable_batch_per_feature=True and a later call passes sharding_ctx=None, this will invoke a VariableBatchPooledEmbeddingsReduceScatter without its required arguments, causing a runtime error. Either require that sharding_ctx always be provided, or make the _dist selection independent of the first call.
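One possible direction, sketched below with the names from the snippet above (the _vb_dist attribute is hypothetical and assumes both modules are built in __init__), is to keep both reduce-scatter modules around and dispatch per call, so the behavior no longer depends on whichever sharding_ctx arrived first.

    # Hypothetical restructuring (not the PR's code): assume __init__ created both
    # self._dist (PooledEmbeddingsReduceScatter) and self._vb_dist
    # (VariableBatchPooledEmbeddingsReduceScatter), then dispatch per call.
    def forward(
        self,
        local_embs: torch.Tensor,
        sharding_ctx: Optional[EmbeddingShardingContext] = None,
    ) -> Awaitable[torch.Tensor]:
        if sharding_ctx is None:
            # No context: plain reduce-scatter over equally sized per-rank batches.
            return self._dist(local_embs)
        if sharding_ctx.variable_batch_per_feature:
            return self._vb_dist(
                local_embs,
                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
                embedding_dims=self._embedding_dims,
            )
        return self._dist(
            local_embs,
            input_splits=sharding_ctx.batch_size_per_rank,
        )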

