Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ multichip_tpu_test(
":jaxite",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
],
)

Expand Down
11 changes: 7 additions & 4 deletions jaxite/jaxite_bool/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This test is separated from the other tests because it can only be run on TPUs.
"""

import jax
from jaxite.jaxite_bool import bool_params
from jaxite.jaxite_bool import jaxite_bool
from absl.testing import absltest
Expand Down Expand Up @@ -36,21 +37,22 @@ def test_pmap_lut3(self) -> None:

# For input (a, b, c, tt),
# each output is constructed as (tt >> 0b{a, b, c}) & 1
num_devices = jax.local_device_count()
inputs = [
(ct_true, ct_false, ct_true, 221), # false
(ct_true, ct_true, ct_false, 221), # true
# Forge only gives tests 2 cores, so we can't test parallelism beyond
# two operations at once.
# 2: (ct_false, ct_false, ct_false, 220), # false
]
][:num_devices]
outputs = jaxite_bool.pmap_lut3(
inputs, self.server_key_set, self.boolean_params
)

output_cleartexts = [
jaxite_bool.decrypt(value, self.client_key_set) for value in outputs
]
expected = [False, True]
expected = [False, True][:num_devices]
self.assertEqual(expected, output_cleartexts)

def test_pmap_lut2(self) -> None:
Expand All @@ -59,20 +61,21 @@ def test_pmap_lut2(self) -> None:

# For input (a, b, tt),
# each output is constructed as (tt >> 0b{a, b}) & 1
num_devices = jax.local_device_count()
inputs = [
(ct_true, ct_false, 13), # false
(ct_true, ct_true, 13), # true
# Forge only gives tests 2 cores, so we can't test parallelism beyond
# two operations at once.
]
][:num_devices]
outputs = jaxite_bool.pmap_lut2(
inputs, self.server_key_set, self.boolean_params
)

output_cleartexts = [
jaxite_bool.decrypt(value, self.client_key_set) for value in outputs
]
expected = [False, True]
expected = [False, True][:num_devices]
self.assertEqual(expected, output_cleartexts)


Expand Down
Loading