Port output distribution classes to DynamicEmb #297
z52527 wants to merge 3 commits into NVIDIA:main from
Conversation
Track CI here.
Greptile Overview
Greptile Summary
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Planner as rw_sharding.py
participant OutDist as output_dist.py
participant A2A as SequenceEmbeddingsAllToAll
participant RS as PooledEmbeddingsReduceScatter
participant VRS as VariableBatchPooledEmbeddingsReduceScatter
Planner->>OutDist: create_output_dist(pg, num_features/embedding_dims, codecs)
Note over Planner,OutDist: pg is currently Optional in sharding classes
alt Sequence embeddings
Planner->>OutDist: RwSequenceEmbeddingDist.forward(local_embs, SequenceShardingContext)
OutDist->>A2A: alltoall(lengths/input_splits/output_splits/...)
A2A-->>OutDist: Awaitable[Tensor]
else Pooled embeddings
Planner->>OutDist: RwPooledEmbeddingDist.forward(local_embs, EmbeddingShardingContext?)
alt first call initializes dist module
OutDist->>RS: init reduce-scatter
OutDist->>VRS: init variable-batch reduce-scatter (if ctx.variable_batch_per_feature)
end
alt ctx is None
OutDist->>RS: reduce_scatter(local_embs)
else ctx.variable_batch_per_feature
OutDist->>VRS: reduce_scatter(local_embs, batch_size_per_rank_per_feature, embedding_dims)
else ctx provided
OutDist->>RS: reduce_scatter(local_embs, input_splits=batch_size_per_rank)
end
end
def create_output_dist(
    self,
    device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
    return RwSequenceEmbeddingDist(
        # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
        # `Optional[ProcessGroup]`.
        self._pg,
        self._get_num_features(),
        device if device is not None else self._device,
        qcomm_codecs_registry=self.qcomm_codecs_registry,
    )
Passes Optional ProcessGroup
create_output_dist() passes self._pg through to RwSequenceEmbeddingDist, but self._pg is typed as Optional[ProcessGroup] (and a pyre-fixme has been added to silence the type error). If self._pg is actually None at runtime (e.g., in non-distributed / single-rank setups), RwSequenceEmbeddingDist.__init__ will call pg.size() and crash. This needs either a real guard or a guarantee that _pg is non-None before the dist module is constructed (the same pattern appears in the pooled sharding class).
Also appears in: corelib/dynamicemb/dynamicemb/planner/rw_sharding.py:264-274.
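A minimal sketch of the guard option, assuming the method shape from the diff above; the error message and the decision to raise (rather than fall back to a single-rank path) are illustrative, not the PR's actual fix:

```python
# Sketch only (not the PR's code): validate the Optional[ProcessGroup]
# instead of silencing the type error with a pyre-fixme.
def create_output_dist(
    self,
    device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
    if self._pg is None:
        # RwSequenceEmbeddingDist.__init__ calls pg.size(); fail here with a
        # clear message instead of crashing inside the dist module.
        raise ValueError(
            "A ProcessGroup is required to build RwSequenceEmbeddingDist; "
            "self._pg is None (non-distributed / single-rank setup?)"
        )
    return RwSequenceEmbeddingDist(
        self._pg,
        self._get_num_features(),
        device if device is not None else self._device,
        qcomm_codecs_registry=self.qcomm_codecs_registry,
    )
```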
| """ | ||
| if self._dist is None: | ||
| self._create_output_dist_module(sharding_ctx) | ||
|
|
||
| if sharding_ctx is None: | ||
| return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs) | ||
| elif sharding_ctx.variable_batch_per_feature: | ||
| return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)( | ||
| local_embs, | ||
| batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature, | ||
| embedding_dims=self._embedding_dims, | ||
| ) | ||
| else: | ||
| return cast(PooledEmbeddingsReduceScatter, self._dist)( | ||
| local_embs, | ||
| input_splits=sharding_ctx.batch_size_per_rank, | ||
| ) |
Missing context for init
When sharding_ctx is None, forward() invokes self._dist as a plain PooledEmbeddingsReduceScatter (lines 134-135), but _dist may have been created as a variable-batch module based on the first call's sharding_ctx (lines 131-133 / 151-155). If the first invocation had variable_batch_per_feature=True and a later call passes sharding_ctx=None, this will call a VariableBatchPooledEmbeddingsReduceScatter without its required arguments and fail at runtime. Either require that sharding_ctx always be provided, or make the _dist selection independent of the first call.
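One way to express the "require sharding_ctx" option, sketched inside forward() under the assumption that the class keeps the single self._dist attribute shown in the diff above (illustrative only, not the PR's code):

```python
# Sketch only: surface a clear error when sharding_ctx is omitted after
# _dist was initialized as the variable-batch module on an earlier call.
if self._dist is None:
    self._create_output_dist_module(sharding_ctx)

if sharding_ctx is None:
    if isinstance(self._dist, VariableBatchPooledEmbeddingsReduceScatter):
        raise RuntimeError(
            "sharding_ctx must be provided: the output dist was initialized "
            "for variable-batch pooled embeddings and requires "
            "batch_size_per_rank_per_feature and embedding_dims"
        )
    return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
```

The comment's alternative, choosing the dist module independently of the first call's context, would remove the need for this check but changes when the collective modules are allocated.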
Description
#296
Port TorchRec's output distribution classes to the DynamicEmb library, enabling future performance optimizations.
Checklist