Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 85 additions & 60 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ py_library(
deps = [
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@jaxite_deps_absl//:pkg:app",
"@jaxite_deps_gmpy2//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
# copybara: jax/experimental:pallas_lib
# copybara: jax/experimental:pallas_tpu
"@jaxite_deps_numpy//:pkg",
"@jaxite_deps_pandas//:pkg",
],
)

Expand Down Expand Up @@ -180,60 +182,6 @@ tpu_test(
],
)

tpu_test(
name = "jaxite_word_ntt_test",
size = "large",
timeout = "eternal",
srcs = ["jaxite/jaxite_word/ntt_test.py"],
shard_count = 3,
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_gmpy2//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

tpu_test(
name = "jaxite_word_sub_test",
size = "large",
timeout = "eternal",
srcs = ["jaxite/jaxite_word/sub_test.py"],
shard_count = 3,
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_gmpy2//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

tpu_test(
name = "add_test",
size = "large",
timeout = "eternal",
srcs = ["jaxite/jaxite_word/add_test.py"],
shard_count = 3,
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "decomposition_test",
size = "small",
Expand Down Expand Up @@ -440,7 +388,7 @@ py_test(
name = "rns_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_ckks/rns_test.py"],
srcs = ["jaxite/jaxite_word/rns_test.py"],
deps = [
":jaxite",
":test_utils",
Expand All @@ -454,20 +402,97 @@ py_test(
],
)

gpu_tpu_test(
name = "ckks_test",
py_test(
name = "ntt_sm_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_ckks/ckks_test.py"],
srcs = ["jaxite/jaxite_word/ntt_sm_test.py"],
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "ntt_mm_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_word/ntt_mm_test.py"],
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "finite_field_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_word/finite_field_test.py"],
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "ckks_ctx_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_word/ckks_ctx_test.py"],
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "ciphertext_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_word/ciphertext_test.py"],
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

py_test(
name = "bconv_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_word/bconv_test.py"],
deps = [
":jaxite",
":test_utils",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
"@jaxite_deps_parameterized//:pkg",
],
)
Loading
Loading