Commit c572f98

Fix dp sharding for compute_logits_func (#1212)
1 parent 364d685 commit c572f98

1 file changed: +1 -1 lines changed

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ def jit_compute_logits_func(self):
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(None, "model"))),
+                                         PartitionSpec("data", "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
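
The one-line change above can be illustrated with a minimal, self-contained JAX sketch. It is not the wrapper's actual code: `toy_logits`, `make_logits_fn`, the mesh layout, and the array sizes are hypothetical stand-ins chosen so the example runs even on a single CPU device. It only shows what the two `PartitionSpec` values request: `PartitionSpec(None, "model")` leaves the first (batch) dimension of the output unpartitioned, while `PartitionSpec("data", "model")` asks for it to be split across the mesh's `"data"` axis, which is presumably the data-parallel fix the commit title refers to.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh: put every visible device on the "data" axis and keep
# "model" at size 1, so the sketch also runs on a single CPU device.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))


def make_logits_fn(spec):
    """Build a jitted toy projection whose output uses the given spec."""

    @functools.partial(jax.jit, out_shardings=NamedSharding(mesh, spec))
    def toy_logits(hidden, weight):
        # Stand-in for a real logits computation: (batch, d) @ (d, vocab).
        return hidden @ weight

    return toy_logits


# Batch size is a multiple of the "data" axis size so it divides evenly.
batch = 4 * devices.shape[0]
hidden = jnp.ones((batch, 8))
weight = jnp.ones((8, 16))

# Before the commit: batch dimension not partitioned.
old = make_logits_fn(PartitionSpec(None, "model"))(hidden, weight)
# After the commit: batch dimension partitioned over "data".
new = make_logits_fn(PartitionSpec("data", "model"))(hidden, weight)

# Per-device shard shapes: the full batch for `old`, a slice of it for `new`.
print(old.addressable_shards[0].data.shape)
print(new.addressable_shards[0].data.shape)
```

Both calls return the same values; only the layout of the output across devices differs. With more than one device on the `"data"` axis, each shard of `new` holds `batch / n` rows, whereas each shard of `old` holds all `batch` rows.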