Skip to content

Commit ec409f9

Browse files
authored
Fix image tests. (#1253)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
1 parent b95ee39 commit ec409f9

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/runner/test_multimodal_manager.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def test_execute_mm_encoder_single_image(self):
144144
assert passed_pixel_values.dtype == jnp.bfloat16
145145

146146
# Convert torch tensor for comparison
147-
expected_pixel_values = dummy_pixel_values.unsqueeze(0).unsqueeze(
148-
0).to(torch.float32).numpy().astype(jnp.bfloat16)
147+
expected_pixel_values = dummy_pixel_values.unsqueeze(0).to(
148+
torch.float32).numpy().astype(jnp.bfloat16)
149149
np.testing.assert_array_equal(np.asarray(passed_pixel_values),
150150
expected_pixel_values)
151151

@@ -249,11 +249,10 @@ def test_execute_mm_encoder_multiple_images(self):
249249
assert "pixel_values" in kwargs_arg
250250

251251
passed_pixel_values = kwargs_arg['pixel_values']
252-
assert passed_pixel_values.shape == (2, 1, 3, 224, 224)
252+
assert passed_pixel_values.shape == (2, 3, 224, 224)
253253

254-
expected_pixel_values = torch.stack(
255-
[px_1, px_2],
256-
dim=0).unsqueeze(1).to(torch.float32).numpy().astype(jnp.bfloat16)
254+
expected_pixel_values = torch.stack([px_1, px_2], dim=0).to(
255+
torch.float32).numpy().astype(jnp.bfloat16)
257256
np.testing.assert_array_equal(np.asarray(passed_pixel_values),
258257
expected_pixel_values)
259258

0 commit comments

Comments
 (0)