Skip to content

Commit 61b9cbe

Browse files
authored
Fixing a few failures in tests/test_quantization.py. (#1258)
1 parent ec409f9 commit 61b9cbe

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_quantization.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def setUp(self):
112112
self.mesh = Mesh(jax.devices(), ('model', ))
113113
self.rng = jax.random.PRNGKey(0)
114114
self.model = SimpleModel(rngs=nnx.Rngs(0))
115+
self.model.vllm_config = MagicMock()
116+
self.model.vllm_config.model_config.use_mla = False
115117

116118
self.qwix_config = [
117119
{
@@ -131,6 +133,7 @@ def test_quantization_call_with_correct_args(self, mock_quantize_model):
131133
"""Test that qwix.quantize_model is called with the correct arguments."""
132134
quantized_model_mock = MagicMock(spec=nnx.Module)
133135
mock_quantize_model.return_value = quantized_model_mock
136+
self.model.vllm_config.sharding_config.total_dp_size = 1
134137

135138
with patch(
136139
"tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",

0 commit comments

Comments
 (0)