diff --git a/BUILD b/BUILD index c29bf34..d4b28be 100644 --- a/BUILD +++ b/BUILD @@ -457,6 +457,8 @@ multichip_tpu_test( ":jaxite", "@abseil-py//absl/testing:absltest", "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", ], ) diff --git a/jaxite/jaxite_bool/pmap_test.py b/jaxite/jaxite_bool/pmap_test.py index 6bb42cb..cf90890 100644 --- a/jaxite/jaxite_bool/pmap_test.py +++ b/jaxite/jaxite_bool/pmap_test.py @@ -3,6 +3,7 @@ This test is separated from the other tests because it can only be run on TPUs. """ +import jax from jaxite.jaxite_bool import bool_params from jaxite.jaxite_bool import jaxite_bool from absl.testing import absltest @@ -36,13 +37,14 @@ def test_pmap_lut3(self) -> None: # For input (a, b, c, tt), # each output is constructed as (tt >> 0b{a, b, c}) & 1 + num_devices = jax.local_device_count() inputs = [ (ct_true, ct_false, ct_true, 221), # false (ct_true, ct_true, ct_false, 221), # true # Forge only gives tests 2 cores, so we can't test parallelism beyond # two operations at once. # 2: (ct_false, ct_false, ct_false, 220), # false - ] + ][:num_devices] outputs = jaxite_bool.pmap_lut3( inputs, self.server_key_set, self.boolean_params ) @@ -50,7 +52,7 @@ def test_pmap_lut3(self) -> None: output_cleartexts = [ jaxite_bool.decrypt(value, self.client_key_set) for value in outputs ] - expected = [False, True] + expected = [False, True][:num_devices] self.assertEqual(expected, output_cleartexts) def test_pmap_lut2(self) -> None: @@ -59,12 +61,13 @@ def test_pmap_lut2(self) -> None: # For input (a, b, tt), # each output is constructed as (tt >> 0b{a, b}) & 1 + num_devices = jax.local_device_count() inputs = [ (ct_true, ct_false, 13), # false (ct_true, ct_true, 13), # true # Forge only gives tests 2 cores, so we can't test parallelism beyond # two operations at once. - ] + ][:num_devices] outputs = jaxite_bool.pmap_lut2( inputs, self.server_key_set, self.boolean_params ) @@ -72,7 +75,7 @@ def test_pmap_lut2(self) -> None: output_cleartexts = [ jaxite_bool.decrypt(value, self.client_key_set) for value in outputs ] - expected = [False, True] + expected = [False, True][:num_devices] self.assertEqual(expected, output_cleartexts)