From a8e6c276273bbe2cb7f2508ba07d1d192a467c13 Mon Sep 17 00:00:00 2001 From: Janmajayamall <40303619+Janmajayamall@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:39:31 +0530 Subject: [PATCH] Add version 0.1.0 (#2) Add version 0.1.0. Includes: - FHE Uint8, FHE bool, div by 0 error flag, and mux APIs - non-interactive mpc for <= 8 parties - interactive mpc for <= 8 parties - multi-party decryption - necessary examples --- .gitignore | 1 + Cargo.lock | 517 +++++- Cargo.toml | 51 +- README.md | 86 + benches/modulus.rs | 152 ++ benches/ntt.rs | 151 ++ examples/bomberman.rs | 178 ++ examples/div_by_zero.rs | 126 ++ examples/if_and_else.rs | 107 ++ examples/interactive_fheuint8.rs | 180 ++ examples/meeting_friends.rs | 150 ++ examples/non_interactive_fheuint8.rs | 177 ++ src/backend.rs | 163 -- src/backend/mod.rs | 141 ++ src/backend/modulus_u64.rs | 337 ++++ src/backend/power_of_2.rs | 112 ++ src/backend/word_size.rs | 124 ++ src/bool/evaluator.rs | 2323 ++++++++++++++++++++++++++ src/bool/keys.rs | 1559 +++++++++++++++++ src/bool/mod.rs | 266 +++ src/bool/mp_api.rs | 697 ++++++++ src/bool/ni_mp_api.rs | 459 +++++ src/bool/parameters.rs | 738 ++++++++ src/bool/print_noise.rs | 1020 +++++++++++ src/decomposer.rs | 383 ++++- src/lib.rs | 132 +- src/lwe.rs | 395 +++-- src/main.rs | 4 +- src/multi_party.rs | 286 ++++ src/ntt.rs | 325 ++-- src/num.rs | 3 - src/pbs.rs | 482 ++++++ src/random.rs | 215 ++- src/rgsw.rs | 466 ------ src/rgsw/keygen.rs | 677 ++++++++ src/rgsw/mod.rs | 982 +++++++++++ src/rgsw/runtime.rs | 1063 ++++++++++++ src/shortint/enc_dec.rs | 370 ++++ src/shortint/mod.rs | 294 ++++ src/shortint/ops.rs | 356 ++++ src/utils.rs | 270 ++- 41 files changed, 15362 insertions(+), 1156 deletions(-) create mode 100644 README.md create mode 100644 benches/modulus.rs create mode 100644 benches/ntt.rs create mode 100644 examples/bomberman.rs create mode 100644 examples/div_by_zero.rs create mode 100644 examples/if_and_else.rs create mode 100644 examples/interactive_fheuint8.rs create mode 100644 examples/meeting_friends.rs create mode 100644 examples/non_interactive_fheuint8.rs delete mode 100644 src/backend.rs create mode 100644 src/backend/mod.rs create mode 100644 src/backend/modulus_u64.rs create mode 100644 src/backend/power_of_2.rs create mode 100644 src/backend/word_size.rs create mode 100644 src/bool/evaluator.rs create mode 100644 src/bool/keys.rs create mode 100644 src/bool/mod.rs create mode 100644 src/bool/mp_api.rs create mode 100644 src/bool/ni_mp_api.rs create mode 100644 src/bool/parameters.rs create mode 100644 src/bool/print_noise.rs create mode 100644 src/multi_party.rs delete mode 100644 src/num.rs create mode 100644 src/pbs.rs delete mode 100644 src/rgsw.rs create mode 100644 src/rgsw/keygen.rs create mode 100644 src/rgsw/mod.rs create mode 100644 src/rgsw/runtime.rs create mode 100644 src/shortint/enc_dec.rs create mode 100644 src/shortint/mod.rs create mode 100644 src/shortint/ops.rs diff --git a/.gitignore b/.gitignore index ea8c4bf..5f084eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/.obsidian \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index cb828f8..47e2125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,27 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" + [[package]] name = "autocfg" version = "1.2.0" @@ -9,16 +30,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] -name = "bin-rs" -version = "0.1.0" -dependencies = [ - "itertools", - "num-bigint-dig", - "num-traits", - "rand", - "rand_chacha", - "rand_distr", -] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "byteorder" @@ -26,12 +41,137 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "either" version = "1.11.0" @@ -49,6 +189,42 @@ dependencies = [ "wasi", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -58,6 +234,21 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -79,6 +270,18 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -126,6 +329,59 @@ dependencies = [ "libm", ] +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "phantom-zone" +version = "0.1.0" +dependencies = [ + "criterion", + "itertools 0.12.1", + "num-bigint-dig", + "num-traits", + "rand", + "rand_chacha", + "rand_distr", +] + +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -190,6 +446,70 @@ dependencies = [ "rand", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.198" @@ -210,6 +530,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "smallvec" version = "1.13.2" @@ -233,14 +564,180 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" diff --git a/Cargo.toml b/Cargo.toml index 74fc978..f899433 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,12 @@ [package] -name = "bin-rs" +name = "phantom-zone" version = "0.1.0" edition = "2021" +readme = "README.md" +repository = "https://github.com/gausslabs/phantom-zone" +license = "MIT" +keywords = ["fhe", "mpc", "cryptography"] +description = "Library for multi-party computation using fully-homomorphic encryption" [dependencies] itertools = "0.12.0" @@ -10,3 +15,47 @@ rand = "0.8.5" rand_chacha = "0.3.1" rand_distr = "0.4.3" num-bigint-dig = { version = "0.8.4", features = ["prime"] } + +[dev-dependencies] +criterion = "0.5.1" + +[features] +interactive_mp = [] +non_interactive_mp = [] + +[[bench]] +name = "ntt" +harness = false + +[[bench]] +name = "modulus" +harness = false + +[[example]] +name = "interactive_fheuint8" +path = "./examples/interactive_fheuint8.rs" + +[[example]] +name = "non_interactive_fheuint8" +path = "./examples/non_interactive_fheuint8.rs" +required-features = ["non_interactive_mp"] + +[[example]] +name = "meeting_friends" +path = "./examples/meeting_friends.rs" +required-features = ["non_interactive_mp"] + +[[example]] +name = "bomberman" +path = "./examples/bomberman.rs" +required-features = ["non_interactive_mp"] + +[[example]] +name = "div_by_zero" +path = "./examples/div_by_zero.rs" +required-features = ["non_interactive_mp"] + +[[example]] +name = "if_and_else" +path = "./examples/if_and_else.rs" +required-features = ["non_interactive_mp"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..b0d6318 --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +**Phantom zone** is similar to the zone where superman gets locked and observes everything outside the zone, but no one outside can see or hear superman. However our phantom zone isn't meant to lock anyone. Instead it's meant to be a new zone in parallel to reality. It's the zone to which you teleport yourself with others, take arbitrary actions, and remember only predefined set of memories when you're back. Think of the zone as a computer that erases itself off of the face of the earth after it returns the output, leaving no trace behind. + +**More formally, phantom-zone is a experimental multi-party computation library that uses multi-party fully homomorphic encryption to compute arbitrary functions on private inputs from multiple parties.** + +At the moment phantom-zone is pretty limited in its functionality. It offers to write circuits with encrypted 8 bit unsigned integers (referred to as FheUint8) and only supports upto 8 parties. FheUint8 supports the same arithmetic as a regular uint8, with a few exceptions mentioned below. We plan to extend APIs to other signed/unsigned types in the future. + +We provide two types of multi-party protocols, both only differ in key-generation procedure. +1. **Non-interactive multi-party protocol,** which requires a single shot message from the clients to the server after which the server can evaluate any arbitrary function on encrypted client inputs. +2. **Interactive multi-party protocol**, a 2 round protocol where in the first round clients interact to generate collective public key and in the second round clients send their server key share to the server, after which server can evaluate any arbitrary function on encrypted client inputs. + +Understanding that library is in experimental stage, if you want to use it for your application but find it lacking in features, please don't shy away from opening a issue or getting in touch. We don't want you to hold back those imaginative horses! + +## How to use + +We provide a brief overview below. Please refer to detailed examples, especially [non_interactive_fheuint8](./examples/non_interactive_fheuint8.rs) and [interactive_fheuint8](./examples/interactive_fheuint8.rs), to understand how to instantiate and run the protocols. + +### Non-interactive multi-party + +Each client is assigned an `id`, referred to as `user_id`, which denotes serial no. of the client out of total clients participating in the multi-party protocol. After learning their `user_id`, the client uploads their server key share along with encryptions of private inputs in a single shot message to the server. Server can then evaluate any arbitrary function on clients' private inputs. New private inputs can be provided in the future by the fix set of parties that participated in the protocol. + +### Interactive multi-party + +Like the non-interactive multi-party, each client is assigned `user_id`. After learning their `id`, clients participate in a 2 round protocol. In round 1, clients generate public key shares, share it with each other, and aggregate public key shares to produce the collective public key. In round 2, clients use the collective public key to generate their server key shares and encrypt their private inputs. Server receives server key shares and encryptions of private inputs from each client. Server aggregates the server key shares, after which it can evaluate any arbitrary function on clients' private inputs. New private inputs can be provided in the future by anyone with access to collective public key. + +### Multi-party decryption + +To decrypt output ciphertext(s) obtained as result of some computation, the clients come online. They download output ciphertext(s) from the server, generate decryption shares, and share it with other parties. Clients, after receiving decryption shares of other parties, aggregate the shares and decrypt the ciphertext(s). + +### Parameter selection + +We provide parameters to run both multi-party protocols for upto 8 parties. + +| $\leq$ # Parties | Interactive multi-party | Non-interactive multi-party | +| ------------ | ----------------------- | --------------------------- | +| 2 | InteractiveLTE2Party | NonInteractiveLTE2Party | +| 4 | InteractiveLTE4Party | NonInteractiveLTE4Party | +| 8 | InteractiveLTE8Party | NonInteractiveLTE8Party | + +If you have use-case `> 8` parties, please open an issue. We're willing to find suitable parameters for `> 8` parties. + +Parameters supporting `<= N` parties must not be used for multi-party compute between `> N` parties. This will lead to increase in failure probability. + +### Feature selection + +To use the library for non-interactive multi-party, you must add `non_interactive_mp` feature flag like `--features "non_interactive_mp"`. And to use the library for interactive multi-party you must add `interactive_mp` feature flag like `--features "interactive_mp"`. + +### FheUInt8 + +We provide APIs for all basic arithmetic (+, -, x, /, %) and comparison operations. + +All arithmetic operation by default wrap around (i.e. $\mod{256}$ for FheUint8). We also provide `overflow_{add/add_assign}` and `overflow_sub` that returns a flag ciphertext which is set to `True` if addition/subtraction overflowed and to `False` otherwise. + +Division operation (/) returns `quotient` and the remainder operation (%) returns `remainder` s.t. `dividend = division x quotient + remainder`. If both `quotient` and `remainder` are required, then `div_rem` can be used. In case of division by zero, [Div by zero error flag](#Div-by-zero-error-flag) will be set and `quotient` will be set to `255` and `remainder` to equal `dividend`. + +**Div by zero error flag** + +In encrypted domain there's no way to panic upon division by zero at runtime. Instead we set a local flag ciphertext accessible via `div_by_zero_flag()` that stores a boolean ciphertext indicating whether any of the divisions performed during the execution attempted division by zero. Assuming division by zero detection is critical for your application, we recommend decrypting the flag ciphertext along with other output ciphertexts in multi-party decryption procedure. + +The div by zero flag is thread local. If you run multiple different FHE circuits in sequence without stopping the thread (i.e. within a single program) you will have to reset div by zero error flag before starting of the next in-sequence circuit execution with `reset_error_flags()`. + +Please refer to [div_by_zero](./examples/div_by_zero.rs) example for more details. + +**If and else using mux** + +Branching in encrypted domain is expensive because the code must execute all the branches. Hence cost grows exponentially with no. of conditional branches. In general we recommend to modify the code to minimise conditional branches. However, if a code cannot be modified to made branchless, we provide `mux` API for FheUint8s. `mux` selects one of the two FheUint8s based on a selector bit. Please refer to [if_and_else](./examples/if_and_else.rs) example for more details. + +## Security + +> [!WARNING] +> Code has not been audited and we currently do not provide any security guarantees outside of the cryptographic parameters. We don't recommend to deploy it in production or use to handle sensitive data. + +All provided parameters are $2^{128}$ ring operations secure according to [lattice estimator](https://github.com/malb/lattice-estimator) and have failure probability of $\leq 2^{-40}$. However, there are two vital points to keep in mind: + +1. Users must not generate two different decryption shares for the same ciphertext, as it can lead to key-recovery attacks. To avoid this, we suggest users maintain a local table listing ciphertext against any previously generated decryption share. Then only generate a new decryption share if ciphertext does not exist in the table, otherwise return the existing share. We believe this should be handled by the library and will add support for this in future. +2. Users must not run the MPC protocol more than once for the same application seed and produce different outputs, as it can lead to key-recovery attacks. We believe this should be handled by the library and will add support for this in future. + +## Credits + +- We thank Barry Whitehat and Brian Lawrence for many helpful discussions. +- We thank Vivek Bhupatiraju and Andrew Lu for for many insightful discussions on fascinating phantom zone applications. +- Non-interactive multi-party RLWE key generation setup is a new protocol designed by Jean Philippe and Janmajaya. We thank Christian Mouchet for the review of the protocol and helpful suggestions. +- We thank Yao Wang for his help with rust compiler troubles. + +## References +1. [Efficient FHEW Bootstrapping with Small Evaluation Keys, and Applications to Threshold Homomorphic Encryption](https://eprint.iacr.org/2022/198.pdf) +2. [Multiparty Homomorphic Encryption from Ring-Learning-with-Errors](https://eprint.iacr.org/2020/304.pdf) \ No newline at end of file diff --git a/benches/modulus.rs b/benches/modulus.rs new file mode 100644 index 0000000..f18a740 --- /dev/null +++ b/benches/modulus.rs @@ -0,0 +1,152 @@ +use bin_rs::{ + ArithmeticLazyOps, ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, + ShoupMatrixFMA, VectorOps, +}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use itertools::{izip, Itertools}; +use rand::{thread_rng, Rng}; +use rand_distr::Uniform; + +fn decompose_r(r: &[u64], decomp_r: &mut [Vec], decomposer: &DefaultDecomposer) { + let ring_size = r.len(); + // let d = decomposer.decomposition_count(); + // let mut count = 0; + for ri in 0..ring_size { + // let el_decomposed = decomposer.decompose(&r[ri]); + decomposer + .decompose_iter(&r[ri]) + .enumerate() + .into_iter() + .for_each(|(j, el)| { + decomp_r[j][ri] = el; + }); + } +} + +fn matrix_fma(out: &mut [u64], a: &Vec>, b: &Vec>, modop: &ModularOpsU64) { + izip!(a.iter(), b.iter()).for_each(|(a_r, b_r)| { + izip!(out.iter_mut(), a_r.iter(), b_r.iter()) + .for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul_lazy(ai, bi))); + }); +} + +fn benchmark_decomposer(c: &mut Criterion) { + let mut group = c.benchmark_group("decomposer"); + + // let decomposers = vec![]; + // 55 + for prime in [36028797017456641] { + for ring_size in [1 << 11] { + let logb = 11; + let decomposer = DefaultDecomposer::new(prime, logb, 2); + + let mut rng = thread_rng(); + let dist = Uniform::new(0, prime); + let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + + group.bench_function( + BenchmarkId::new( + "decompose", + format!( + "q={prime}/N={ring_size}/logB={logb}/d={}", + *decomposer.decomposition_count().as_ref() + ), + ), + |b| { + b.iter_batched_ref( + || { + ( + a.clone(), + vec![ + vec![0u64; ring_size]; + *decomposer.decomposition_count().as_ref() + ], + ) + }, + |(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + } + + group.finish(); +} + +fn benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("modulus"); + // 55 + for prime in [36028797017456641] { + for ring_size in [1 << 11] { + let modop = ModularOpsU64::new(prime); + + let mut rng = thread_rng(); + let dist = Uniform::new(0, prime); + + let a0 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + + let d = 1; + let a0_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + // a0 in shoup representation + let a0_shoup_matrix = a0_matrix + .iter() + .map(|r| { + r.iter() + .map(|v| { + // $(v * 2^{\beta}) / p$ + ((*v as u128 * (1u128 << 64)) / prime as u128) as u64 + }) + .collect_vec() + }) + .collect_vec(); + let a1_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + + group.bench_function( + BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")), + |b| { + b.iter_batched_ref( + || (vec![0u64; ring_size]), + |(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new( + "matrix_shoup_fma_lazy", + format!("q={prime}/N={ring_size}/d={d}"), + ), + |b| { + b.iter_batched_ref( + || (vec![0u64; ring_size]), + |(out)| { + black_box(modop.shoup_matrix_fma( + out, + &a0_matrix, + &a0_shoup_matrix, + &a1_matrix, + )) + }, + criterion::BatchSize::PerIteration, + ) + }, + ); + } + } + + group.finish(); +} + +criterion_group!(decomposer, benchmark_decomposer); +criterion_group!(modulus, benchmark); +criterion_main!(modulus, decomposer); diff --git a/benches/ntt.rs b/benches/ntt.rs new file mode 100644 index 0000000..b9155c1 --- /dev/null +++ b/benches/ntt.rs @@ -0,0 +1,151 @@ +use bin_rs::{Ntt, NttBackendU64, NttInit}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use itertools::Itertools; +use rand::{thread_rng, Rng}; +use rand_distr::Uniform; + +fn forward_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut().for_each(|r| nttop.forward(r.as_mut_slice())); +} + +fn forward_lazy_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut() + .for_each(|r| nttop.forward_lazy(r.as_mut_slice())); +} + +fn backward_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut().for_each(|r| nttop.backward(r.as_mut_slice())); +} + +fn backward_lazy_matrix(a: &mut [Vec], nttop: &NttBackendU64) { + a.iter_mut() + .for_each(|r| nttop.backward_lazy(r.as_mut_slice())); +} + +fn benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("ntt"); + // 55 + for prime in [36028797017456641] { + for ring_size in [1 << 11] { + let ntt = NttBackendU64::new(&prime, ring_size); + let mut rng = thread_rng(); + + let a = (&mut rng) + .sample_iter(Uniform::new(0, prime)) + .take(ring_size) + .collect_vec(); + let d = 2; + let a_matrix = (0..d) + .map(|_| { + (&mut rng) + .sample_iter(Uniform::new(0, prime)) + .take(ring_size) + .collect_vec() + }) + .collect_vec(); + + { + group.bench_function( + BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.forward(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.forward_lazy(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(forward_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new( + "forward_lazy_matrix", + format!("q={prime}/N={ring_size}/d={d}"), + ), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + + { + group.bench_function( + BenchmarkId::new("backward", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.backward(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a.clone(), + |mut a| black_box(ntt.backward_lazy(&mut a)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new("backward_matrix", format!("q={prime}/N={ring_size}")), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(backward_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + + group.bench_function( + BenchmarkId::new( + "backward_lazy_matrix", + format!("q={prime}/N={ring_size}/d={d}"), + ), + |b| { + b.iter_batched_ref( + || a_matrix.clone(), + |a_matrix| black_box(backward_lazy_matrix(a_matrix, &ntt)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + } + } + + group.finish(); +} + +criterion_group!(ntt, benchmark); +criterion_main!(ntt); diff --git a/examples/bomberman.rs b/examples/bomberman.rs new file mode 100644 index 0000000..8438c09 --- /dev/null +++ b/examples/bomberman.rs @@ -0,0 +1,178 @@ +use std::fmt::Debug; + +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +struct Coordinates(T, T); +impl Coordinates { + fn new(x: T, y: T) -> Self { + Coordinates(x, y) + } + fn x(&self) -> &T { + &self.0 + } + + fn y(&self) -> &T { + &self.1 + } +} + +impl Debug for Coordinates +where + T: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Coordinates") + .field("x", self.x()) + .field("y", self.y()) + .finish() + } +} + +fn coordinates_is_equal(a: &Coordinates, b: &Coordinates) -> FheBool { + &(a.x().eq(b.x())) & &(a.y().eq(b.y())) +} + +/// Traverse the map with `p0` moves and check whether any of the moves equal +/// bomb coordinates (in encrypted domain) +fn traverse_map(p0: &[Coordinates], bomb_coords: &[Coordinates]) -> FheBool { + // First move + let mut out = coordinates_is_equal(&p0[0], &bomb_coords[0]); + bomb_coords.iter().skip(1).for_each(|b_coord| { + out |= coordinates_is_equal(&p0[0], &b_coord); + }); + + // rest of the moves + p0.iter().skip(1).for_each(|m_coord| { + bomb_coords.iter().for_each(|b_coord| { + out |= coordinates_is_equal(m_coord, b_coord); + }); + }); + + out +} + +// Do you recall bomberman? It's an interesting game where the bomberman has to +// cross the map without stepping on strategically placed bombs all over the +// map. Below we implement a very basic prototype of bomberman with 4 players. +// +// The map has 256 tiles with bottom left-most tile labelled with coordinates +// (0,0) and top right-most tile labelled with coordinates (255, 255). There are +// 4 players: Player 0, Player 1, Player 2, Player 3. Player 0's task is to walk +// across the map with fixed no. of moves while preventing itself from stepping +// on any of the bombs placed on the map by Player 1, 2, and 3. +// +// The twist is that Player 0's moves and the locations of bombs placed by other +// players are encrypted. Player 0 moves across the map in encrypted domain. +// Only a boolean output indicating whether player 0 survived after all the +// moves or killed itself by stepping onto a bomb is revealed at the end. If +// player 0 survives, Player 1, 2, 3 never learn what moves did Player 0 make. +// If Player 0 kills itself by stepping onto a bomb, it only learns that bomb +// was placed on one of the coordinates it moved to. Moreover, Player 1, 2, 3 +// never learn locations of each other bombs or whose bomb killed Player 0. +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE4Party); + + // set application's common reference seed + let mut seed = [0; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 4; + + // Client side // + + // Players generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // Players generate server keys + let server_key_shares = cks + .iter() + .enumerate() + .map(|(index, k)| gen_server_key_share(index, no_of_parties, k)) + .collect_vec(); + + // Player 0 describes its moves as sequence of coordinates on the map + let no_of_moves = 10; + let player_0_moves = (0..no_of_moves) + .map(|_| Coordinates::new(thread_rng().gen::(), thread_rng().gen())) + .collect_vec(); + // Coordinates of bomb placed by Player 1 + let player_1_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); + // Coordinates of bomb placed by Player 2 + let player_2_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); + // Coordinates of bomb placed by Player 3 + let player_3_bomb = Coordinates::new(thread_rng().gen::(), thread_rng().gen()); + + println!("P0 moves coordinates: {:?}", &player_0_moves); + println!("P1 bomb coordinates : {:?}", &player_1_bomb); + println!("P2 bomb coordinates : {:?}", &player_2_bomb); + println!("P3 bomb coordinates : {:?}", &player_3_bomb); + + // Players encrypt their private inputs + let player_0_enc = cks[0].encrypt( + player_0_moves + .iter() + .flat_map(|c| vec![*c.x(), *c.y()]) + .collect_vec() + .as_slice(), + ); + let player_1_enc = cks[1].encrypt(vec![*player_1_bomb.x(), *player_1_bomb.y()].as_slice()); + let player_2_enc = cks[2].encrypt(vec![*player_2_bomb.x(), *player_2_bomb.y()].as_slice()); + let player_3_enc = cks[3].encrypt(vec![*player_3_bomb.x(), *player_3_bomb.y()].as_slice()); + + // Players upload the encrypted inputs and server key shares to the server + + // Server side // + + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // server parses Player inputs + let player_0_moves_enc = { + let c = player_0_enc + .unseed::>>() + .key_switch(0) + .extract_all(); + c.into_iter() + .chunks(2) + .into_iter() + .map(|mut x_y| Coordinates::new(x_y.next().unwrap(), x_y.next().unwrap())) + .collect_vec() + }; + let player_1_bomb_enc = { + let c = player_1_enc.unseed::>>().key_switch(1); + Coordinates::new(c.extract_at(0), c.extract_at(1)) + }; + let player_2_bomb_enc = { + let c = player_2_enc.unseed::>>().key_switch(2); + Coordinates::new(c.extract_at(0), c.extract_at(1)) + }; + let player_3_bomb_enc = { + let c = player_3_enc.unseed::>>().key_switch(3); + Coordinates::new(c.extract_at(0), c.extract_at(1)) + }; + + // Server runs the game + let player_0_dead_ct = traverse_map( + &player_0_moves_enc, + &vec![player_1_bomb_enc, player_2_bomb_enc, player_3_bomb_enc], + ); + + // Client side // + + // Players generate decryption shares and send them to each other + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&player_0_dead_ct)) + .collect_vec(); + // Players decrypt to find whether Player 0 survived + let player_0_dead = cks[0].aggregate_decryption_shares(&player_0_dead_ct, &decryption_shares); + + if player_0_dead { + println!("Oops! Player 0 dead"); + } else { + println!("Wohoo! Player 0 survived"); + } +} diff --git a/examples/div_by_zero.rs b/examples/div_by_zero.rs new file mode 100644 index 0000000..e5b56f5 --- /dev/null +++ b/examples/div_by_zero.rs @@ -0,0 +1,126 @@ +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // Generate server key shares + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Aggregate server key shares and set the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // -------- + + // We attempt to divide by 0 in encrypted domain and then check whether div by 0 + // error flag is set to True. + let numerator = thread_rng().gen::(); + let numerator_enc = cks[0] + .encrypt(vec![numerator].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let zero_enc = cks[1] + .encrypt(vec![0].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&zero_enc); + + // When attempting to divide by zero, for uint8 quotient is always 255 and + // remainder = numerator + let quotient = cks[0].aggregate_decryption_shares( + "ient_enc, + &cks.iter() + .map(|k| k.gen_decryption_share("ient_enc)) + .collect_vec(), + ); + let remainder = cks[0].aggregate_decryption_shares( + &remainder_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&remainder_enc)) + .collect_vec(), + ); + assert!(quotient == 255); + assert!(remainder == numerator); + + // Div by zero error flag must be True + let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set"); + let div_by_zero = cks[0].aggregate_decryption_shares( + &div_by_zero_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&div_by_zero_enc)) + .collect_vec(), + ); + assert!(div_by_zero == true); + + // ------- + + // div by zero error flag is thread local. If we were to run another circuit + // without stopping the thread (i.e. within the same program as previous + // one), we must reset errors flags set by previous circuit with + // `reset_error_flags()` to prevent error flags of previous circuit affecting + // the flags of the next circuit. + reset_error_flags(); + + // We divide again but with non-zero denominator this time and check that div + // by zero flag is set to False + let numerator = thread_rng().gen::(); + let mut denominator = thread_rng().gen::(); + while denominator == 0 { + denominator = thread_rng().gen::(); + } + let numerator_enc = cks[0] + .encrypt(vec![numerator].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let denominator_enc = cks[1] + .encrypt(vec![denominator].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&denominator_enc); + let quotient = cks[0].aggregate_decryption_shares( + "ient_enc, + &cks.iter() + .map(|k| k.gen_decryption_share("ient_enc)) + .collect_vec(), + ); + let remainder = cks[0].aggregate_decryption_shares( + &remainder_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&remainder_enc)) + .collect_vec(), + ); + assert!(quotient == numerator.div_euclid(denominator)); + assert!(remainder == numerator.rem_euclid(denominator)); + + // Div by zero error flag must be set to False + let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set"); + let div_by_zero = cks[0].aggregate_decryption_shares( + &div_by_zero_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&div_by_zero_enc)) + .collect_vec(), + ); + assert!(div_by_zero == false); +} diff --git a/examples/if_and_else.rs b/examples/if_and_else.rs new file mode 100644 index 0000000..bf2f8ea --- /dev/null +++ b/examples/if_and_else.rs @@ -0,0 +1,107 @@ +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +/// Code that runs when conditional branch is `True` +fn circuit_branch_true(a: &FheUint8, b: &FheUint8) -> FheUint8 { + a + b +} + +/// Code that runs when conditional branch is `False` +fn circuit_branch_false(a: &FheUint8, b: &FheUint8) -> FheUint8 { + a * b +} + +// Conditional branching (ie. If and else) are generally expensive in encrypted +// domain. The code must execute all the branches, and, as apparent, the +// runtime cost grows exponentially with no. of conditional branches. +// +// In general we recommend to write branchless code. In case the code cannot be +// modified to be branchless, the code must execute all branches and use a +// muxer to select correct output at the end. +// +// Below we showcase example of a single conditional branch in encrypted domain. +// The code executes both the branches (i.e. program runs both If and Else) and +// selects output of one of the branches with a mux. +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // Generate server key shares + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Aggregate server key shares and set the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // ------- + + // User 0 encrypts their private input `v_a` and User 1 encrypts their + // private input `v_b`. We want to execute: + // + // if v_a < v_b: + // return v_a + v_b + // else: + // return v_a * v_b + // + // We define two functions + // (1) `circuit_branch_true`: which executes v_a + v_b in encrypted domain. + // (2) `circuit_branch_false`: which executes v_a * v_b in encrypted + // domain. + // + // The circuit runs both `circuit_branch_true` and `circuit_branch_false` and + // then selects the output of `circuit_branch_true` if `v_a < v_b == TRUE` + // otherwise selects the output of `circuit_branch_false` if `v_a < v_b == + // FALSE` using mux. + + // Clients private inputs + let v_a = thread_rng().gen::(); + let v_b = thread_rng().gen::(); + let v_a_enc = cks[0] + .encrypt(vec![v_a].as_slice()) + .unseed::>>() + .key_switch(0) + .extract_at(0); + let v_b_enc = cks[1] + .encrypt(vec![v_b].as_slice()) + .unseed::>>() + .key_switch(1) + .extract_at(0); + + // Run both branches + let out_true_enc = circuit_branch_true(&v_a_enc, &v_b_enc); + let out_false_enc = circuit_branch_false(&v_a_enc, &v_b_enc); + + // define condition select v_a < v_b + let selector_bit = v_a_enc.lt(&v_b_enc); + + // select output of `circuit_branch_true` if selector_bit == TRUE otherwise + // select output of `circuit_branch_false` + let out_enc = out_true_enc.mux(&out_false_enc, &selector_bit); + + let out = cks[0].aggregate_decryption_shares( + &out_enc, + &cks.iter() + .map(|k| k.gen_decryption_share(&out_enc)) + .collect_vec(), + ); + let want_out = if v_a < v_b { + v_a.wrapping_add(v_b) + } else { + v_a.wrapping_mul(v_b) + }; + assert_eq!(out, want_out); +} diff --git a/examples/interactive_fheuint8.rs b/examples/interactive_fheuint8.rs new file mode 100644 index 0000000..cf785b1 --- /dev/null +++ b/examples/interactive_fheuint8.rs @@ -0,0 +1,180 @@ +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 { + ((a + b) * c) * d +} + +fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(&(a + b) * c) * d +} + +fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 { + (a * b) + (c * d) +} + +fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(a * b) + &(c * d) +} + +fn main() { + // Select parameter set + set_parameter_set(ParameterSelector::InteractiveLTE4Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 4; + + // Client side // + + // Clients generate their private keys + let cks = (0..no_of_parties) + .into_iter() + .map(|_| gen_client_key()) + .collect_vec(); + + // -- Round 1 -- // + // In round 1 each client generates their share for the collective public key. + // They send public key shares to each other with or out without the server. + // After receiving others public key shares clients independently aggregate + // the shares and produce the collective public key `pk` + + let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec(); + + // Clients aggregate public key shares to produce collective public key `pk` + let pk = aggregate_public_key_shares(&pk_shares); + + // -- Round 2 -- // + // In round 2 each client generates server key share using the public key `pk`. + // Clients may also encrypt their private inputs using collective public key + // `pk`. Each client then uploads their server key share and private input + // ciphertexts to the server. + + // Clients generate server key shares + // + // We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client + // 2, and user_id 4 to client 4. + // + // Note that `user_id`'s must be unique among the clients and must be less than + // total number of clients. + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_id, k)| collective_server_key_share(k, user_id, no_of_parties, &pk)) + .collect_vec(); + + // Each client encrypts their private inputs using the collective public key + // `pk`. Unlike non-inteactive MPC protocol, private inputs are + // encrypted using collective public key. + let c0_a = thread_rng().gen::(); + let c0_enc = pk.encrypt(vec![c0_a].as_slice()); + let c1_a = thread_rng().gen::(); + let c1_enc = pk.encrypt(vec![c1_a].as_slice()); + let c2_a = thread_rng().gen::(); + let c2_enc = pk.encrypt(vec![c2_a].as_slice()); + let c3_a = thread_rng().gen::(); + let c3_enc = pk.encrypt(vec![c3_a].as_slice()); + + // Clients upload their server key along with private encrypted inputs to + // the server + + // Server side // + + // Server receives server key shares from each client and proceeds to + // aggregate the shares and produce the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // Server proceeds to extract clients private inputs + // + // Clients encrypt their FheUint8s inputs packed in a batched ciphertext. + // The server must extract clients private inputs from the batch ciphertext + // either (1) using `extract_at(index)` to extract `index`^{th} FheUint8 + // ciphertext (2) or using `extract_all()` to extract all available FheUint8s + // (3) or using `extract_many(many)` to extract first `many` available FheUint8s + let c0_a_enc = c0_enc.extract_at(0); + let c1_a_enc = c1_enc.extract_at(0); + let c2_a_enc = c2_enc.extract_at(0); + let c3_a_enc = c3_enc.extract_at(0); + + // Server proceeds to evaluate function1 on clients private inputs + let ct_out_f1 = function1_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc); + + // After server has finished evaluating the circuit on client private + // inputs, clients can proceed to multi-party decryption protocol to + // decrypt output ciphertext + + // Client Side // + + // In multi-party decryption protocol, client must come online, download the + // output ciphertext from the server, product "output ciphertext" dependent + // decryption share, and send it to other parties. After receiving + // decryption shares of other parties, clients independently aggregate the + // decrytion shares and decrypt the output ciphertext. + + // Clients generate decryption shares + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out_f1)) + .collect_vec(); + + // After receiving decryption shares from other parties, clients aggregate the + // shares and decrypt output ciphertext + let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares); + + // Check correctness of function1 output + let want_f1 = function1(c0_a, c1_a, c2_a, c3_a); + assert!(out_f1 == want_f1); + + // -------- + + // Once server key is produced it can be re-used across different functions + // with different private client inputs for the same set of clients. + // + // Here we run `function2_fhe` for the same of clients but with different + // private inputs. Clients do not need to participate in the 2 round + // protocol again, instead they only upload their new private inputs to the + // server. + + // Clients encrypt their private inputs + let c0_a = thread_rng().gen::(); + let c0_enc = pk.encrypt(vec![c0_a].as_slice()); + let c1_a = thread_rng().gen::(); + let c1_enc = pk.encrypt(vec![c1_a].as_slice()); + let c2_a = thread_rng().gen::(); + let c2_enc = pk.encrypt(vec![c2_a].as_slice()); + let c3_a = thread_rng().gen::(); + let c3_enc = pk.encrypt(vec![c3_a].as_slice()); + + // Clients uploads only their new private inputs to the server + + // Server side // + + // Server receives private inputs from the clients, extracts them, and + // proceeds to evaluate `function2_fhe` + let c0_a_enc = c0_enc.extract_at(0); + let c1_a_enc = c1_enc.extract_at(0); + let c2_a_enc = c2_enc.extract_at(0); + let c3_a_enc = c3_enc.extract_at(0); + + let ct_out_f2 = function2_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc); + + // Client side // + + // Clients generate decryption shares for `ct_out_f2` + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out_f2)) + .collect_vec(); + + // Clients aggregate decryption shares and decrypt `ct_out_f2` + let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares); + + // We check correctness of function2 + let want_f2 = function2(c0_a, c1_a, c2_a, c3_a); + assert!(want_f2 == out_f2); +} diff --git a/examples/meeting_friends.rs b/examples/meeting_friends.rs new file mode 100644 index 0000000..d3a877c --- /dev/null +++ b/examples/meeting_friends.rs @@ -0,0 +1,150 @@ +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +struct Location(T, T); + +impl Location { + fn new(x: T, y: T) -> Self { + Location(x, y) + } + + fn x(&self) -> &T { + &self.0 + } + fn y(&self) -> &T { + &self.1 + } +} + +fn should_meet(a: &Location, b: &Location, b_threshold: &u8) -> bool { + let diff_x = a.x() - b.x(); + let diff_y = a.y() - b.y(); + let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y); + + d_sq.le(b_threshold) +} + +/// Calculates distance square between a's and b's location. Returns a boolean +/// indicating whether diatance sqaure is <= `b_threshold`. +fn should_meet_fhe( + a: &Location, + b: &Location, + b_threshold: &FheUint8, +) -> FheBool { + let diff_x = a.x() - b.x(); + let diff_y = a.y() - b.y(); + let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y); + + d_sq.le(b_threshold) +} + +// Ever wondered who are the long distance friends (friends of friends or +// friends of friends of friends...) that live nearby ? But how do you find +// them? Surely no-one will simply reveal their exact location just because +// there's a slight chance that a long distance friend lives nearby. +// +// Here we write a simple application with two users `a` and `b`. User `a` wants +// to find (long distance) friends that live in their neighbourhood. User `b` is +// open to meeting new friends within some distance of their location. Both user +// `a` and `b` encrypt their locations and upload their encrypted locations to +// the server. User `b` also encrypts the distance square threshold within which +// they are interested in meeting new friends and sends encrypted distance +// square threshold to the server. +// The server calculates the square of the distance between user a's location +// and user b's location and produces encrypted boolean output indicating +// whether square of distance is <= user b's supplied distance square threshold. +// User `a` then comes online, downloads output ciphertext, produces their +// decryption share for user `b`, and uploads the decryption share to the +// server. User `b` comes online, downloads output ciphertext and user a's +// decryption share, produces their own decryption share, and then decrypts the +// encrypted boolean output. If the output is `True`, it indicates user `a` is +// within the distance square threshold defined by user `b`. +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Client Side // + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // We assign user_id 0 to user `a` and user_id 1 user `b` + let a_id = 0; + let b_id = 1; + let user_a_secret = &cks[0]; + let user_b_secret = &cks[1]; + + // User `a` and `b` generate server key shares + let a_server_key_share = gen_server_key_share(a_id, no_of_parties, user_a_secret); + let b_server_key_share = gen_server_key_share(b_id, no_of_parties, user_b_secret); + + // User `a` and `b` encrypt their locations + let user_a_secret = &cks[0]; + let user_a_location = Location::new(thread_rng().gen::(), thread_rng().gen::()); + let user_a_enc = + user_a_secret.encrypt(vec![*user_a_location.x(), *user_a_location.y()].as_slice()); + + let user_b_location = Location::new(thread_rng().gen::(), thread_rng().gen::()); + // User `b` also encrypts the distance square threshold + let user_b_threshold = 40; + let user_b_enc = user_b_secret + .encrypt(vec![*user_b_location.x(), *user_b_location.y(), user_b_threshold].as_slice()); + + // Server Side // + + // Both user `a` and `b` upload their private inputs and server key shares to + // the server in single shot message + let server_key = aggregate_server_key_shares(&vec![a_server_key_share, b_server_key_share]); + server_key.set_server_key(); + + // Server parses private inputs from user `a` and `b` + let user_a_location_enc = { + let c = user_a_enc.unseed::>>().key_switch(a_id); + Location::new(c.extract_at(0), c.extract_at(1)) + }; + let (user_b_location_enc, user_b_threshold_enc) = { + let c = user_b_enc.unseed::>>().key_switch(b_id); + ( + Location::new(c.extract_at(0), c.extract_at(1)), + c.extract_at(2), + ) + }; + + // run the circuit + let out_c = should_meet_fhe( + &user_a_location_enc, + &user_b_location_enc, + &user_b_threshold_enc, + ); + + // Client Side // + + // user `a` comes online, downloads `out_c`, produces a decryption share, and + // uploads the decryption share to the server. + let a_dec_share = user_a_secret.gen_decryption_share(&out_c); + + // user `b` comes online downloads user `a`'s decryption share, generates their + // own decryption share, decrypts the output ciphertext. If the output is + // True, user `b` contacts user `a` to meet. + let b_dec_share = user_b_secret.gen_decryption_share(&out_c); + let out_bool = + user_b_secret.aggregate_decryption_shares(&out_c, &vec![b_dec_share, a_dec_share]); + + assert_eq!( + out_bool, + should_meet(&user_a_location, &user_b_location, &user_b_threshold) + ); + + if out_bool { + println!("A lives nearby. B should meet A."); + } else { + println!("A lives too far away!") + } +} diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs new file mode 100644 index 0000000..19b09d6 --- /dev/null +++ b/examples/non_interactive_fheuint8.rs @@ -0,0 +1,177 @@ +use itertools::Itertools; +use phantom_zone::*; +use rand::{thread_rng, Rng, RngCore}; + +fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 { + ((a + b) * c) * d +} + +fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(&(a + b) * c) * d +} + +fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 { + (a * b) + (c * d) +} + +fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(a * b) + &(c * d) +} + +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveLTE4Party); + + // set application's common reference seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 4; + + // Clide side // + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // client 0 encrypts its private inputs + let c0_a = thread_rng().gen::(); + // Clients encrypt their private inputs in a seeded batched ciphertext using + // their private RLWE secret `u_j`. + let c0_enc = cks[0].encrypt(vec![c0_a].as_slice()); + + // client 1 encrypts its private inputs + let c1_a = thread_rng().gen::(); + let c1_enc = cks[1].encrypt(vec![c1_a].as_slice()); + + // client 2 encrypts its private inputs + let c2_a = thread_rng().gen::(); + let c2_enc = cks[2].encrypt(vec![c2_a].as_slice()); + + // client 3 encrypts its private inputs + let c3_a = thread_rng().gen::(); + let c3_enc = cks[3].encrypt(vec![c3_a].as_slice()); + + // Clients independently generate their server key shares + // + // We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client + // 2, user_id 3 to client 3. + // + // Note that `user_id`s must be unique among the clients and must be less than + // total number of clients. + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Each client uploads their server key shares and encrypted private inputs to + // the server in a single shot message. + + // Server side // + + // Server receives server key shares from each client and proceeds to aggregate + // them to produce the server key. After this point, server can use the server + // key to evaluate any arbitrary function on encrypted private inputs from + // the fixed set of clients + + // aggregate server shares and generate the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // Server proceeds to extract private inputs sent by clients + // + // To extract client 0's (with user_id=0) private inputs we first key switch + // client 0's private inputs from theit secret `u_j` to ideal secret of the mpc + // protocol. To indicate we're key switching client 0's private input we + // supply client 0's `user_id` i.e. we call `key_switch(0)`. Then we extract + // the first ciphertext by calling `extract_at(0)`. + // + // Since client 0 only encrypts 1 input in batched ciphertext, calling + // extract_at(index) for `index` > 0 will panic. If client 0 had more private + // inputs then we can either extract them all at once with `extract_all` or + // first `many` of them with `extract_many(many)` + let ct_c0_a = c0_enc.unseed::>>().key_switch(0).extract_at(0); + + let ct_c1_a = c1_enc.unseed::>>().key_switch(1).extract_at(0); + let ct_c2_a = c2_enc.unseed::>>().key_switch(2).extract_at(0); + let ct_c3_a = c3_enc.unseed::>>().key_switch(3).extract_at(0); + + // After extracting each client's private inputs, server proceeds to evaluate + // function1 + let now = std::time::Instant::now(); + let ct_out_f1 = function1_fhe(&ct_c0_a, &ct_c1_a, &ct_c2_a, &ct_c3_a); + println!("Function1 FHE evaluation time: {:?}", now.elapsed()); + + // Server has finished running compute. Clients can proceed to decrypt the + // output ciphertext using multi-party decryption. + + // Client side // + + // In multi-party decryption, each client needs to come online, download output + // ciphertext from the server, produce "output ciphertext" dependent decryption + // share, and send it to other parties (either via p2p or via server). After + // receving decryption shares from other parties, clients can independently + // decrypt output ciphertext. + + // each client produces decryption share + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out_f1)) + .collect_vec(); + + // With all decryption shares, clients can aggregate the shares and decrypt the + // ciphertext + let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares); + + // we check correctness of function1 + let want_out_f1 = function1(c0_a, c1_a, c2_a, c3_a); + assert_eq!(out_f1, want_out_f1); + + // ----------- + + // Server key can be re-used for different functions with different private + // client inputs for the same set of clients. + // + // Here we run `function2_fhe` for the same set of client but with new inputs. + // Clients only have to upload their private inputs to the server this time. + + // Each client encrypts their private input + let c0_a = thread_rng().gen::(); + let c0_enc = cks[0].encrypt(vec![c0_a].as_slice()); + let c1_a = thread_rng().gen::(); + let c1_enc = cks[1].encrypt(vec![c1_a].as_slice()); + let c2_a = thread_rng().gen::(); + let c2_enc = cks[2].encrypt(vec![c2_a].as_slice()); + let c3_a = thread_rng().gen::(); + let c3_enc = cks[3].encrypt(vec![c3_a].as_slice()); + + // Clients upload only their new private inputs to the server + + // Server side // + + // Server receives clients private inputs and extracts them + let ct_c0_a = c0_enc.unseed::>>().key_switch(0).extract_at(0); + let ct_c1_a = c1_enc.unseed::>>().key_switch(1).extract_at(0); + let ct_c2_a = c2_enc.unseed::>>().key_switch(2).extract_at(0); + let ct_c3_a = c3_enc.unseed::>>().key_switch(3).extract_at(0); + + // Server proceeds to evaluate `function2_fhe` + let now = std::time::Instant::now(); + let ct_out_f2 = function2_fhe(&ct_c0_a, &ct_c1_a, &ct_c2_a, &ct_c3_a); + println!("Function2 FHE evaluation time: {:?}", now.elapsed()); + + // Client side // + + // Each client generates decrytion share for `ct_out_f2` + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out_f2)) + .collect_vec(); + + // Clients independently aggregate the shares and decrypt + let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares); + + // We check correctness of function2 + let want_out_f2 = function2(c0_a, c1_a, c2_a, c3_a); + assert_eq!(out_f2, want_out_f2); +} diff --git a/src/backend.rs b/src/backend.rs deleted file mode 100644 index 2b3f92c..0000000 --- a/src/backend.rs +++ /dev/null @@ -1,163 +0,0 @@ -use itertools::izip; - -pub trait VectorOps { - type Element; - - fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element); - fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]); - - fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); - fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); - fn elwise_neg_mut(&self, a: &mut [Self::Element]); - /// inplace mutates `a`: a = a + b*c - fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]); - - fn modulus(&self) -> Self::Element; -} - -pub trait ArithmeticOps { - type Element; - - fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; - fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; - fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; - - fn modulus(&self) -> Self::Element; -} - -pub struct ModularOpsU64 { - q: u64, - logq: usize, - barrett_mu: u128, - barrett_alpha: usize, -} - -impl ModularOpsU64 { - pub fn new(q: u64) -> ModularOpsU64 { - let logq = 64 - q.leading_zeros(); - - // barrett calculation - let mu = (1u128 << (logq * 2 + 3)) / (q as u128); - let alpha = logq + 3; - - ModularOpsU64 { - q, - logq: logq as usize, - barrett_alpha: alpha as usize, - barrett_mu: mu, - } - } - - fn add_mod_fast(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.q); - debug_assert!(b < self.q); - - let mut o = a + b; - if o >= self.q { - o -= self.q; - } - o - } - - fn sub_mod_fast(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < self.q); - debug_assert!(b < self.q); - - if a > b { - a - b - } else { - (self.q + a) - b - } - } - - /// returns (a * b) % q - /// - /// - both a and b must be in range [0, 2q) - /// - output is in range [0 , q) - fn mul_mod_fast(&self, a: u64, b: u64) -> u64 { - debug_assert!(a < 2 * self.q); - debug_assert!(b < 2 * self.q); - - let ab = a as u128 * b as u128; - - // ab / (2^{n + \beta}) - // note: \beta is assumed to -2 - let tmp = ab >> (self.logq - 2); - - // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)} - let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2); - - // ab - k*p - let tmp = k * (self.q as u128); - - let mut out = (ab - tmp) as u64; - - if out >= self.q { - out -= self.q; - } - - return out; - } -} - -impl ArithmeticOps for ModularOpsU64 { - type Element = u64; - - fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.add_mod_fast(*a, *b) - } - - fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.mul_mod_fast(*a, *b) - } - - fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.sub_mod_fast(*a, *b) - } - - fn modulus(&self) -> Self::Element { - self.q - } -} - -impl VectorOps for ModularOpsU64 { - type Element = u64; - - fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { - *ai = self.add_mod_fast(*ai, *bi); - }); - } - - fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { - *ai = self.mul_mod_fast(*ai, *bi); - }); - } - - fn elwise_neg_mut(&self, a: &mut [Self::Element]) { - a.iter_mut().for_each(|ai| *ai = self.q - *ai); - } - - fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) { - izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| { - *oi = self.mul_mod_fast(*ai, *b); - }); - } - - fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) { - izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| { - *oi = self.mul_mod_fast(*ai, *bi); - }); - } - - fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) { - izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| { - *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci)); - }); - } - - fn modulus(&self) -> Self::Element { - self.q - } -} diff --git a/src/backend/mod.rs b/src/backend/mod.rs new file mode 100644 index 0000000..d3c0721 --- /dev/null +++ b/src/backend/mod.rs @@ -0,0 +1,141 @@ +use num_traits::ToPrimitive; + +use crate::{utils::log2, Row}; + +mod modulus_u64; +mod power_of_2; +mod word_size; + +pub use modulus_u64::ModularOpsU64; +pub(crate) use power_of_2::ModulusPowerOf2; + +pub trait Modulus { + type Element; + /// Modulus value if it fits in Element + fn q(&self) -> Option; + /// Log2 of `q` + fn log_q(&self) -> usize; + /// Modulus value as f64 if it fits in f64 + fn q_as_f64(&self) -> Option; + /// Is modulus native? + fn is_native(&self) -> bool; + /// -1 in signed representaiton + fn neg_one(&self) -> Self::Element; + /// Largest unsigned value that fits in the modulus. That is, q - 1. + fn largest_unsigned_value(&self) -> Self::Element; + /// Smallest unsigned value that fits in the modulus + /// Always assmed to be 0. + fn smallest_unsigned_value(&self) -> Self::Element; + /// Convert unsigned value in signed represetation to i64 + fn map_element_to_i64(&self, v: &Self::Element) -> i64; + /// Convert f64 to signed represented in modulus + fn map_element_from_f64(&self, v: f64) -> Self::Element; + /// Convert i64 to signed represented in modulus + fn map_element_from_i64(&self, v: i64) -> Self::Element; +} + +impl Modulus for u64 { + type Element = u64; + fn is_native(&self) -> bool { + // q that fits in u64 can never be a native modulus + false + } + fn largest_unsigned_value(&self) -> Self::Element { + self - 1 + } + fn neg_one(&self) -> Self::Element { + self - 1 + } + fn smallest_unsigned_value(&self) -> Self::Element { + 0 + } + fn map_element_to_i64(&self, v: &Self::Element) -> i64 { + assert!(v <= self, "{v} must be <= {self}"); + if *v >= (self >> 1) { + -ToPrimitive::to_i64(&(self - v)).unwrap() + } else { + ToPrimitive::to_i64(v).unwrap() + } + } + fn map_element_from_f64(&self, v: f64) -> Self::Element { + let v = v.round(); + let v_u64 = v.abs().to_u64().unwrap(); + assert!(v_u64 <= self.largest_unsigned_value()); + if v < 0.0 { + self - v_u64 + } else { + v_u64 + } + } + fn map_element_from_i64(&self, v: i64) -> Self::Element { + let v_u64 = v.abs().to_u64().unwrap(); + assert!(v_u64 <= self.largest_unsigned_value()); + if v < 0 { + self - v_u64 + } else { + v_u64 + } + } + fn q(&self) -> Option { + Some(*self) + } + fn q_as_f64(&self) -> Option { + self.to_f64() + } + fn log_q(&self) -> usize { + log2(&self.q().unwrap()) + } +} + +pub trait ModInit { + type M; + fn new(modulus: Self::M) -> Self; +} + +pub trait GetModulus { + type Element; + type M: Modulus; + fn modulus(&self) -> &Self::M; +} + +pub trait VectorOps { + type Element; + + /// Sets out as `out[i] = a[i] * b` + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element); + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]); + + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element); + fn elwise_neg_mut(&self, a: &mut [Self::Element]); + /// inplace mutates `a`: a = a + b*c + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]); + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ); +} + +pub trait ArithmeticOps { + type Element; + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn neg(&self, a: &Self::Element) -> Self::Element; +} + +pub trait ArithmeticLazyOps { + type Element; + fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; +} + +pub trait ShoupMatrixFMA { + /// Returns summation of `row-wise product of matrix a and b` + out where + /// each element is in range [0, 2q) + fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]); +} diff --git a/src/backend/modulus_u64.rs b/src/backend/modulus_u64.rs new file mode 100644 index 0000000..e3ff495 --- /dev/null +++ b/src/backend/modulus_u64.rs @@ -0,0 +1,337 @@ +use itertools::izip; +use num_traits::WrappingMul; + +use super::{ + ArithmeticLazyOps, ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps, +}; +use crate::RowMut; + +pub struct ModularOpsU64 { + q: u64, + q_twice: u64, + logq: usize, + barrett_mu: u128, + barrett_alpha: usize, + modulus: T, +} + +impl ModInit for ModularOpsU64 +where + T: Modulus, +{ + type M = T; + fn new(modulus: Self::M) -> ModularOpsU64 { + assert!(!modulus.is_native()); + + // largest unsigned value modulus fits is modulus-1 + let q = modulus.largest_unsigned_value() + 1; + let logq = 64 - (q + 1u64).leading_zeros(); + + // barrett calculation + let mu = (1u128 << (logq * 2 + 3)) / (q as u128); + let alpha = logq + 3; + + ModularOpsU64 { + q, + q_twice: q << 1, + logq: logq as usize, + barrett_alpha: alpha as usize, + barrett_mu: mu, + modulus, + } + } +} + +impl ModularOpsU64 { + fn add_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.q); + debug_assert!(b < self.q); + + let mut o = a + b; + if o >= self.q { + o -= self.q; + } + o + } + + fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.q_twice); + debug_assert!(b < self.q_twice); + + let mut o = a + b; + if o >= self.q_twice { + o -= self.q_twice; + } + o + } + + fn sub_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.q); + debug_assert!(b < self.q); + + if a >= b { + a - b + } else { + (self.q + a) - b + } + } + + // returns (a * b) % q + /// + /// - both a and b must be in range [0, 2q) + /// - output is in range [0 , 2q) + fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < 2 * self.q); + debug_assert!(b < 2 * self.q); + + let ab = a as u128 * b as u128; + + // ab / (2^{n + \beta}) + // note: \beta is assumed to -2 + let tmp = ab >> (self.logq - 2); + + // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)} + let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2); + + // ab - k*p + let tmp = k * (self.q as u128); + + (ab - tmp) as u64 + } + + /// returns (a * b) % q + /// + /// - both a and b must be in range [0, 2q) + /// - output is in range [0 , q) + fn mul_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < 2 * self.q); + debug_assert!(b < 2 * self.q); + + let ab = a as u128 * b as u128; + + // ab / (2^{n + \beta}) + // note: \beta is assumed to -2 + let tmp = ab >> (self.logq - 2); + + // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)} + let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2); + + // ab - k*p + let tmp = k * (self.q as u128); + + let mut out = (ab - tmp) as u64; + + if out >= self.q { + out -= self.q; + } + + return out; + } +} + +impl ArithmeticOps for ModularOpsU64 { + type Element = u64; + + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.add_mod_fast(*a, *b) + } + + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.mul_mod_fast(*a, *b) + } + + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.sub_mod_fast(*a, *b) + } + + fn neg(&self, a: &Self::Element) -> Self::Element { + self.q - *a + } + + // fn modulus(&self) -> Self::Element { + // self.q + // } +} + +impl ArithmeticLazyOps for ModularOpsU64 { + type Element = u64; + fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.add_mod_fast_lazy(*a, *b) + } + fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.mul_mod_fast_lazy(*a, *b) + } +} + +impl VectorOps for ModularOpsU64 { + type Element = u64; + + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.add_mod_fast(*ai, *bi); + }); + } + + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.sub_mod_fast(*ai, *bi); + }); + } + + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.mul_mod_fast(*ai, *bi); + }); + } + + fn elwise_neg_mut(&self, a: &mut [Self::Element]) { + a.iter_mut().for_each(|ai| *ai = self.q - *ai); + } + + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) { + izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| { + *oi = self.mul_mod_fast(*ai, *b); + }); + } + + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) { + izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| { + *oi = self.mul_mod_fast(*ai, *bi); + }); + } + + fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) { + a.iter_mut().for_each(|ai| { + *ai = self.mul_mod_fast(*ai, *b); + }); + } + + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) { + izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| { + *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci)); + }); + } + + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *c)); + }); + } + + // fn modulus(&self) -> Self::Element { + // self.q + // } +} + +impl, T> ShoupMatrixFMA for ModularOpsU64 { + fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]) { + assert!(a.len() == a_shoup.len()); + assert!( + a.len() == b.len(), + "Unequal length {}!={}", + a.len(), + b.len() + ); + + let q = self.q; + let q_twice = self.q << 1; + + izip!(a.iter(), a_shoup.iter(), b.iter()).for_each(|(a_row, a_shoup_row, b_row)| { + izip!( + out.as_mut().iter_mut(), + a_row.as_ref().iter(), + a_shoup_row.as_ref().iter(), + b_row.as_ref().iter() + ) + .for_each(|(o, a0, a0_shoup, b0)| { + let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64; + let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o); + v = v.wrapping_sub(q.wrapping_mul(quotient)); + + if v >= q_twice { + v -= q_twice; + } + + *o = v; + }); + }); + } +} + +impl GetModulus for ModularOpsU64 +where + T: Modulus, +{ + type Element = T::Element; + type M = T; + fn modulus(&self) -> &Self::M { + &self.modulus + } +} + +#[cfg(test)] +mod tests { + use super::*; + use itertools::Itertools; + use rand::{thread_rng, Rng}; + use rand_distr::Uniform; + + #[test] + fn fma() { + let mut rng = thread_rng(); + let prime = 36028797017456641; + let ring_size = 1 << 3; + + let dist = Uniform::new(0, prime); + let d = 2; + let a0_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + // a0 in shoup representation + let a0_shoup_matrix = a0_matrix + .iter() + .map(|r| { + r.iter() + .map(|v| { + // $(v * 2^{\beta}) / p$ + ((*v as u128 * (1u128 << 64)) / prime as u128) as u64 + }) + .collect_vec() + }) + .collect_vec(); + let a1_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + + let modop = ModularOpsU64::new(prime); + + let mut out_shoup_fma_lazy = vec![0u64; ring_size]; + modop.shoup_matrix_fma( + &mut out_shoup_fma_lazy, + &a0_matrix, + &a0_shoup_matrix, + &a1_matrix, + ); + let out_shoup_fma = out_shoup_fma_lazy + .iter() + .map(|v| if *v >= prime { v - prime } else { *v }) + .collect_vec(); + + // expected + let mut out_expected = vec![0u64; ring_size]; + izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| { + izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| { + *o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime; + }); + }); + + assert_eq!(out_expected, out_shoup_fma); + } +} diff --git a/src/backend/power_of_2.rs b/src/backend/power_of_2.rs new file mode 100644 index 0000000..e89a6e1 --- /dev/null +++ b/src/backend/power_of_2.rs @@ -0,0 +1,112 @@ +use itertools::izip; + +use crate::{ArithmeticOps, ModInit, VectorOps}; + +use super::{GetModulus, Modulus}; + +pub(crate) struct ModulusPowerOf2 { + modulus: T, + /// Modulus mask: (1 << q) - 1 + mask: u64, +} + +impl ArithmeticOps for ModulusPowerOf2 { + type Element = u64; + #[inline] + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + (a.wrapping_add(*b)) & self.mask + } + #[inline] + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + (a.wrapping_sub(*b)) & self.mask + } + #[inline] + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + (a.wrapping_mul(*b)) & self.mask + } + #[inline] + fn neg(&self, a: &Self::Element) -> Self::Element { + (0u64.wrapping_sub(*a)) & self.mask + } +} + +impl VectorOps for ModulusPowerOf2 { + type Element = u64; + + #[inline] + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_add(*b0)) & self.mask); + } + + #[inline] + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_mul(*b0)) & self.mask); + } + + #[inline] + fn elwise_neg_mut(&self, a: &mut [Self::Element]) { + a.iter_mut() + .for_each(|a0| *a0 = 0u64.wrapping_sub(*a0) & self.mask); + } + #[inline] + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_sub(*b0)) & self.mask); + } + + #[inline] + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) { + izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(a0, b0, c0)| { + *a0 = a0.wrapping_add(b0.wrapping_mul(*c0)) & self.mask; + }); + } + + #[inline] + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ) { + izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| { + *a0 = a0.wrapping_add(b0.wrapping_mul(*c)) & self.mask; + }); + } + #[inline] + fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) { + a.iter_mut() + .for_each(|a0| *a0 = a0.wrapping_mul(*b) & self.mask) + } + + #[inline] + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) { + izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(o0, a0, b0)| { + *o0 = a0.wrapping_mul(*b0) & self.mask; + }); + } + + #[inline] + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) { + izip!(out.iter_mut(), a.iter()).for_each(|(o0, a0)| { + *o0 = a0.wrapping_mul(*b) & self.mask; + }); + } +} + +impl> ModInit for ModulusPowerOf2 { + type M = T; + fn new(modulus: Self::M) -> Self { + assert!(!modulus.is_native()); + assert!(modulus.q().unwrap().is_power_of_two()); + let q = modulus.q().unwrap(); + let mask = q - 1; + Self { modulus, mask } + } +} + +impl> GetModulus for ModulusPowerOf2 { + type Element = u64; + type M = T; + fn modulus(&self) -> &Self::M { + &self.modulus + } +} diff --git a/src/backend/word_size.rs b/src/backend/word_size.rs new file mode 100644 index 0000000..82f5b5a --- /dev/null +++ b/src/backend/word_size.rs @@ -0,0 +1,124 @@ +use itertools::izip; +use num_traits::{WrappingAdd, WrappingMul, WrappingSub, Zero}; + +use super::{ArithmeticOps, GetModulus, ModInit, Modulus, VectorOps}; + +pub struct WordSizeModulus { + modulus: T, +} + +impl ModInit for WordSizeModulus +where + T: Modulus, +{ + type M = T; + fn new(modulus: T) -> Self { + assert!(modulus.is_native()); + // For now assume ModulusOpsU64 is only used for u64 + Self { modulus: modulus } + } +} + +impl ArithmeticOps for WordSizeModulus +where + T: Modulus, + T::Element: WrappingAdd + WrappingSub + WrappingMul + Zero, +{ + type Element = T::Element; + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + T::Element::wrapping_add(a, b) + } + + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + T::Element::wrapping_mul(a, b) + } + + fn neg(&self, a: &Self::Element) -> Self::Element { + T::Element::wrapping_sub(&T::Element::zero(), a) + } + + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + T::Element::wrapping_sub(a, b) + } +} + +impl VectorOps for WordSizeModulus +where + T: Modulus, + T::Element: WrappingAdd + WrappingSub + WrappingMul + Zero, +{ + type Element = T::Element; + + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = T::Element::wrapping_add(ai, bi); + }); + } + + fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = T::Element::wrapping_sub(ai, bi); + }); + } + + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = T::Element::wrapping_mul(ai, bi); + }); + } + + fn elwise_neg_mut(&self, a: &mut [Self::Element]) { + a.iter_mut() + .for_each(|ai| *ai = T::Element::wrapping_sub(&T::Element::zero(), ai)); + } + + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) { + izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| { + *oi = T::Element::wrapping_mul(ai, b); + }); + } + + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) { + izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| { + *oi = T::Element::wrapping_mul(ai, bi); + }); + } + + fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) { + a.iter_mut().for_each(|ai| { + *ai = T::Element::wrapping_mul(ai, b); + }); + } + + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) { + izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| { + *ai = T::Element::wrapping_add(ai, &T::Element::wrapping_mul(bi, ci)); + }); + } + + fn elwise_fma_scalar_mut( + &self, + a: &mut [Self::Element], + b: &[Self::Element], + c: &Self::Element, + ) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = T::Element::wrapping_add(ai, &T::Element::wrapping_mul(bi, c)); + }); + } + + // fn modulus(&self) -> &T { + // &self.modulus + // } +} + +impl GetModulus for WordSizeModulus +where + T: Modulus, +{ + type Element = T::Element; + type M = T; + fn modulus(&self) -> &Self::M { + &self.modulus + } +} diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs new file mode 100644 index 0000000..6b3f0b9 --- /dev/null +++ b/src/bool/evaluator.rs @@ -0,0 +1,2323 @@ +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + marker::PhantomData, + usize, +}; + +use itertools::{izip, Itertools}; +use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero}; +use rand_distr::uniform::SampleUniform; + +use crate::{ + backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps}, + bool::parameters::ParameterVariant, + decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, + lwe::{decrypt_lwe, encrypt_lwe, seeded_lwe_ksk_keygen}, + multi_party::{ + non_interactive_ksk_gen, non_interactive_ksk_zero_encryptions_for_other_party_i, + public_key_share, + }, + ntt::{Ntt, NttInit}, + pbs::{pbs, PbsInfo, PbsKey, WithShoupRepr}, + random::{ + DefaultSecureRng, NewWithSeed, RandomFill, RandomFillGaussianInModulus, + RandomFillUniformInModulus, + }, + rgsw::{ + generate_auto_map, public_key_encrypt_rgsw, rgsw_by_rgsw_inplace, rgsw_x_rgsw_scratch_rows, + rlwe_auto_scratch_rows, rlwe_x_rgsw_scratch_rows, secret_key_encrypt_rgsw, + seeded_auto_key_gen, RgswCiphertextMutRef, RgswCiphertextRef, RuntimeScratchMutRef, + }, + utils::{ + encode_x_pow_si_with_emebedding_factor, mod_exponent, puncture_p_rng, TryConvertFrom1, + WithLocal, + }, + BooleanGates, Encoder, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, +}; + +use super::{ + keys::{ + ClientKey, CommonReferenceSeededCollectivePublicKeyShare, + CommonReferenceSeededInteractiveMultiPartyServerKeyShare, + CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare, + InteractiveMultiPartyClientKey, NonInteractiveMultiPartyClientKey, + SeededInteractiveMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey, + SeededSinglePartyServerKey, SinglePartyClientKey, + }, + parameters::{BoolParameters, CiphertextModulus, DecompositionCount, DoubleDecomposerParams}, +}; + +/// Common reference seed used for Interactive multi-party, +/// +/// Seeds for public key shares and differents parts of server key shares are +/// derived from common reference seed with different puncture rountines. +/// +/// ## Punctures +/// +/// Initial Seed: +/// Puncture 1 -> Public key share seed +/// Puncture 2 -> Main server key share seed +/// Puncture 1 -> Auto keys cipertexts seed +/// Puncture 2 -> LWE ksk seed +#[derive(Clone, PartialEq)] +pub struct InteractiveMultiPartyCrs { + pub(super) seed: S, +} + +impl InteractiveMultiPartyCrs<[u8; 32]> { + pub(super) fn random() -> Self { + DefaultSecureRng::with_local_mut(|rng| { + let mut seed = [0u8; 32]; + rng.fill_bytes(&mut seed); + Self { seed } + }) + } +} + +impl InteractiveMultiPartyCrs { + /// Seed to generate public key share + fn public_key_share_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + puncture_p_rng(&mut prng, 1) + } + + /// Main server key share seed + fn key_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + puncture_p_rng(&mut prng, 2) + } + + pub(super) fn auto_keys_cts_seed + RandomFill>(&self) -> S { + let mut key_prng = Rng::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut key_prng, 1) + } + + pub(super) fn lwe_ksk_cts_seed_seed + RandomFill>(&self) -> S { + let mut key_prng = Rng::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut key_prng, 2) + } +} + +/// Common reference seed used for non-interactive multi-party. +/// +/// Initial Seed +/// Puncture 1 -> Key Seed +/// Puncture 1 -> Rgsw ciphertext seed +/// Puncture l+1 -> Seed for zero encs and non-interactive +/// multi-party RGSW ciphertexts of +/// l^th LWE index. +/// Puncture 2 -> auto keys seed +/// Puncture 3 -> Lwe key switching key seed +/// Puncture 2 -> user specific seed for u_j to s ksk +/// Punture j+1 -> user j's seed +#[derive(Clone, PartialEq)] +pub struct NonInteractiveMultiPartyCrs { + pub(super) seed: S, +} + +impl NonInteractiveMultiPartyCrs<[u8; 32]> { + pub(super) fn random() -> Self { + DefaultSecureRng::with_local_mut(|rng| { + let mut seed = [0u8; 32]; + rng.fill_bytes(&mut seed); + Self { seed } + }) + } +} + +impl NonInteractiveMultiPartyCrs { + fn key_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.seed); + puncture_p_rng(&mut p_rng, 1) + } + + pub(crate) fn ni_rgsw_cts_main_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut p_rng, 1) + } + + pub(crate) fn ni_rgsw_ct_seed_for_index + RandomFill>( + &self, + lwe_index: usize, + ) -> S { + let mut p_rng = R::new_with_seed(self.ni_rgsw_cts_main_seed::()); + puncture_p_rng(&mut p_rng, lwe_index + 1) + } + + pub(crate) fn auto_keys_cts_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut p_rng, 2) + } + + pub(crate) fn lwe_ksk_cts_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.key_seed::()); + puncture_p_rng(&mut p_rng, 3) + } + + fn ui_to_s_ks_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.seed); + puncture_p_rng(&mut p_rng, 2) + } + + pub(crate) fn ui_to_s_ks_seed_for_user_i + RandomFill>( + &self, + user_i: usize, + ) -> S { + let ks_seed = self.ui_to_s_ks_seed::(); + let mut p_rng = R::new_with_seed(ks_seed); + + puncture_p_rng(&mut p_rng, user_i + 1) + } +} + +struct ScratchMemory +where + M: Matrix, +{ + lwe_vector: M::R, + decomposition_matrix: M, +} + +impl ScratchMemory +where + M::R: RowEntity, +{ + fn new(parameters: &BoolParameters) -> Self { + // Vector to store LWE ciphertext with LWE dimesnion n + let lwe_vector = M::R::zeros(parameters.lwe_n().0 + 1); + + // PBS perform two operations at runtime: RLWE x RGW and RLWE auto. Since the + // operations are performed serially same scratch space can be used for both. + // Hence we create scratch space that contains maximum amount of rows that + // suffices for RLWE x RGSW and RLWE auto + let decomposition_matrix = M::zeros( + std::cmp::max( + rlwe_x_rgsw_scratch_rows(parameters.rlwe_by_rgsw_decomposition_params()), + rlwe_auto_scratch_rows(parameters.auto_decomposition_param()), + ), + parameters.rlwe_n().0, + ); + + Self { + lwe_vector, + decomposition_matrix, + } + } +} + +pub(super) trait BoolEncoding { + type Element; + fn true_el(&self) -> Self::Element; + fn false_el(&self) -> Self::Element; + fn qby4(&self) -> Self::Element; + fn decode(&self, m: Self::Element) -> bool; +} + +impl BoolEncoding for CiphertextModulus +where + CiphertextModulus: Modulus, + T: PrimInt + NumInfo, +{ + type Element = T; + + fn qby4(&self) -> Self::Element { + if self.is_native() { + T::one() << ((T::BITS as usize) - 2) + } else { + self.q().unwrap() >> 2 + } + } + /// Q/8 + fn true_el(&self) -> Self::Element { + if self.is_native() { + T::one() << ((T::BITS as usize) - 3) + } else { + self.q().unwrap() >> 3 + } + } + /// -Q/8 + fn false_el(&self) -> Self::Element { + self.largest_unsigned_value() - self.true_el() + T::one() + } + fn decode(&self, m: Self::Element) -> bool { + let qby8 = self.true_el(); + let m = (((m + qby8).to_f64().unwrap() * 4.0f64) / self.q_as_f64().unwrap()).round() + as usize + % 4usize; + + if m == 0 { + return false; + } else if m == 1 { + return true; + } else { + panic!("Incorrect bool decryption. Got m={m} but expected m to be 0 or 1") + } + } +} + +impl Encoder for B +where + B: BoolEncoding, +{ + fn encode(&self, v: bool) -> B::Element { + if v { + self.true_el() + } else { + self.false_el() + } + } +} + +pub(super) struct BoolPbsInfo { + auto_decomposer: DefaultDecomposer, + rlwe_rgsw_decomposer: ( + DefaultDecomposer, + DefaultDecomposer, + ), + lwe_decomposer: DefaultDecomposer, + g_k_dlog_map: Vec, + rlwe_nttop: Ntt, + rlwe_modop: RlweModOp, + lwe_modop: LweModOp, + embedding_factor: usize, + rlwe_qby4: M::MatElement, + rlwe_auto_maps: Vec<(Vec, Vec)>, + parameters: BoolParameters, +} + +impl PbsInfo for BoolPbsInfo +where + M::MatElement: PrimInt + + WrappingSub + + NumInfo + + FromPrimitive + + From + + Display + + WrappingAdd + + Debug, + RlweModOp: ArithmeticOps + ShoupMatrixFMA, + LweModOp: ArithmeticOps + VectorOps, + NttOp: Ntt, +{ + type M = M; + type Modulus = CiphertextModulus; + type D = DefaultDecomposer; + type RlweModOp = RlweModOp; + type LweModOp = LweModOp; + type NttOp = NttOp; + fn rlwe_auto_map(&self, k: usize) -> &(Vec, Vec) { + &self.rlwe_auto_maps[k] + } + fn br_q(&self) -> usize { + *self.parameters.br_q() + } + fn lwe_decomposer(&self) -> &Self::D { + &self.lwe_decomposer + } + fn rlwe_rgsw_decomposer(&self) -> &(Self::D, Self::D) { + &self.rlwe_rgsw_decomposer + } + fn auto_decomposer(&self) -> &Self::D { + &self.auto_decomposer + } + fn embedding_factor(&self) -> usize { + self.embedding_factor + } + fn g(&self) -> isize { + self.parameters.g() as isize + } + fn w(&self) -> usize { + self.parameters.w() + } + fn g_k_dlog_map(&self) -> &[usize] { + &self.g_k_dlog_map + } + fn lwe_n(&self) -> usize { + self.parameters.lwe_n().0 + } + fn lwe_q(&self) -> &Self::Modulus { + self.parameters.lwe_q() + } + fn rlwe_n(&self) -> usize { + self.parameters.rlwe_n().0 + } + fn rlwe_q(&self) -> &Self::Modulus { + self.parameters.rlwe_q() + } + fn modop_lweq(&self) -> &Self::LweModOp { + &self.lwe_modop + } + fn modop_rlweq(&self) -> &Self::RlweModOp { + &self.rlwe_modop + } + fn nttop_rlweq(&self) -> &Self::NttOp { + &self.rlwe_nttop + } +} + +pub(crate) struct BoolEvaluator +where + M: Matrix, +{ + pbs_info: BoolPbsInfo, + scratch_memory: ScratchMemory, + nand_test_vec: M::R, + and_test_vec: M::R, + or_test_vec: M::R, + nor_test_vec: M::R, + xor_test_vec: M::R, + xnor_test_vec: M::R, + /// Non-interactive u_i -> s key switch decomposer + ni_ui_to_s_ks_decomposer: Option>, + _phantom: PhantomData, +} + +impl + BoolEvaluator +{ + pub(crate) fn parameters(&self) -> &BoolParameters { + &self.pbs_info.parameters + } + + pub(super) fn pbs_info(&self) -> &BoolPbsInfo { + &self.pbs_info + } + + pub(super) fn ni_ui_to_s_ks_decomposer(&self) -> &Option> { + &self.ni_ui_to_s_ks_decomposer + } +} + +fn trim_rgsw_ct_matrix_from_rgrg_to_rlrg< + M: MatrixMut + MatrixEntity, + D: DoubleDecomposerParams, +>( + rgsw_ct_in: M, + rgrg_params: D, + rlrg_params: D, +) -> M +where + M::R: RowMut, + M::MatElement: Copy, +{ + let (rgswrgsw_d_a, rgswrgsw_d_b) = ( + rgrg_params.decomposition_count_a(), + rgrg_params.decomposition_count_b(), + ); + let (rlrg_d_a, rlrg_d_b) = ( + rlrg_params.decomposition_count_a(), + rlrg_params.decomposition_count_b(), + ); + let rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; + let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; + assert!(rgsw_ct_in.dimension().0 == rgsw_ct_rows_in); + assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0); + assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0); + + let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rgsw_ct_in.dimension().1); + + // RLWE'(-sm) part A + izip!( + reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 - rlrg_d_a.0) + .take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(-sm) part B + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0) + .take(rlrg_d_a.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0)) + .take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(m) Part A + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2) + .take(rlrg_d_b.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(m) Part B + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) + .take(rlrg_d_b.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + reduced_ct_i_out +} + +fn produce_rgsw_ciphertext_from_ni_rgsw< + M: MatrixMut + MatrixEntity, + D: RlweDecomposer, + ModOp: VectorOps, + NttOp: Ntt, +>( + ni_rgsw_ct: &M, + aggregated_decomposed_ni_rgsw_zero_encs: &[M], + decomposed_neg_ais: &[M], + decomposer: &D, + parameters: &BoolParameters, + uj_to_s_ksk: (&M, &M), + rlwe_modop: &ModOp, + nttop: &NttOp, + out_eval: bool, +) -> M +where + ::R: RowMut + Clone, +{ + let max_decomposer = + if decomposer.a().decomposition_count().0 > decomposer.b().decomposition_count().0 { + decomposer.a() + } else { + decomposer.b() + }; + + assert!( + ni_rgsw_ct.dimension() + == ( + max_decomposer.decomposition_count().0, + parameters.rlwe_n().0 + ) + ); + assert!( + aggregated_decomposed_ni_rgsw_zero_encs.len() == decomposer.a().decomposition_count().0, + ); + assert!(decomposed_neg_ais.len() == decomposer.b().decomposition_count().0); + + let mut rgsw_i = M::zeros( + decomposer.a().decomposition_count().0 * 2 + decomposer.b().decomposition_count().0 * 2, + parameters.rlwe_n().0, + ); + let (rlwe_dash_nsm, rlwe_dash_m) = + rgsw_i.split_at_row_mut(decomposer.a().decomposition_count().0 * 2); + + // RLWE'_{s}(-sm) + // Key switch `s * a_{i, l} + e` using ksk(u_j -> s) to produce RLWE(s * + // u_{j=user_id} * a_{i, l}). + // + // Then set RLWE_{s}(-s B^i m) = (0, u_{j=user_id} * a_{i, l} + e + B^i m) + + // RLWE(s * u_{j=user_id} * a_{i, l}) + { + let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = + rlwe_dash_nsm.split_at_mut(decomposer.a().decomposition_count().0); + izip!( + rlwe_dash_nsm_parta.iter_mut(), + rlwe_dash_nsm_partb.iter_mut(), + ni_rgsw_ct.iter_rows().skip( + max_decomposer.decomposition_count().0 - decomposer.a().decomposition_count().0 + ), + aggregated_decomposed_ni_rgsw_zero_encs.iter() + ) + .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_zero_enc)| { + // KS(s * a_{i, l} + e) = RLWE(s * u_j * + // a_{i, l}) using user j's Ksk + izip!( + decomp_zero_enc.iter_rows(), + uj_to_s_ksk.0.iter_rows(), + uj_to_s_ksk.1.iter_rows() + ) + .for_each(|(c, pb, pa)| { + rlwe_modop.elwise_fma_mut(rlwe_b.as_mut(), pb.as_ref(), c.as_ref()); + rlwe_modop.elwise_fma_mut(rlwe_a.as_mut(), pa.as_ref(), c.as_ref()); + }); + + // RLWE(-s beta^i m) = (0, u_j * a_{j, l} + + // e + beta^i m) + RLWE(s * u_j * a_{i, l}) + if out_eval { + let mut ni_rlwe_ct = ni_rlwe_ct.clone(); + nttop.forward(ni_rlwe_ct.as_mut()); + rlwe_modop.elwise_add_mut(rlwe_a.as_mut(), ni_rlwe_ct.as_ref()); + } else { + nttop.backward(rlwe_a.as_mut()); + nttop.backward(rlwe_b.as_mut()); + rlwe_modop.elwise_add_mut(rlwe_a.as_mut(), ni_rlwe_ct.as_ref()); + } + }); + } + + // RLWE'_{s}(m) + { + let (rlwe_dash_m_parta, rlwe_dash_partb) = + rlwe_dash_m.split_at_mut(decomposer.b().decomposition_count().0); + izip!( + rlwe_dash_m_parta.iter_mut(), + rlwe_dash_partb.iter_mut(), + ni_rgsw_ct.iter_rows().skip( + max_decomposer.decomposition_count().0 - decomposer.b().decomposition_count().0 + ), + decomposed_neg_ais.iter() + ) + .for_each(|(rlwe_a, rlwe_b, ni_rlwe_ct, decomp_neg_ai)| { + // KS(-a_{i, l}) = RLWE(u_i * -a_{i,l}) using user j's Ksk + izip!( + decomp_neg_ai.iter_rows(), + uj_to_s_ksk.0.iter_rows(), + uj_to_s_ksk.1.iter_rows() + ) + .for_each(|(c, pb, pa)| { + rlwe_modop.elwise_fma_mut(rlwe_b.as_mut(), pb.as_ref(), c.as_ref()); + rlwe_modop.elwise_fma_mut(rlwe_a.as_mut(), pa.as_ref(), c.as_ref()); + }); + + // RLWE_{s}(beta^i m) = (u_j * a_{i, l} + e + beta^i m, 0) - + // RLWE(-a_{i, l} u_j) + if out_eval { + let mut ni_rlwe_ct = ni_rlwe_ct.clone(); + nttop.forward(ni_rlwe_ct.as_mut()); + rlwe_modop.elwise_add_mut(rlwe_b.as_mut(), ni_rlwe_ct.as_ref()); + } else { + nttop.backward(rlwe_a.as_mut()); + nttop.backward(rlwe_b.as_mut()); + rlwe_modop.elwise_add_mut(rlwe_b.as_mut(), ni_rlwe_ct.as_ref()); + } + }); + } + + rgsw_i +} + +/// Assigns user with user_id segement of LWE secret indices for which they +/// generate RGSW(X^{s[i]}) as the leader (i.e. for RLWExRGSW). If returned +/// tuple is (start, end), user's segment is [start, end) +pub(super) fn multi_party_user_id_lwe_segment( + user_id: usize, + total_users: usize, + lwe_n: usize, +) -> (usize, usize) { + let per_user = (lwe_n as f64 / total_users as f64) + .ceil() + .to_usize() + .unwrap(); + ( + per_user * user_id, + std::cmp::min(per_user * (user_id + 1), lwe_n), + ) +} + +impl BoolEvaluator +where + M: MatrixEntity + MatrixMut, + M::MatElement: PrimInt + + Debug + + Display + + NumInfo + + FromPrimitive + + WrappingSub + + WrappingAdd + + SampleUniform + + From, + NttOp: Ntt, + RlweModOp: ArithmeticOps + + VectorOps + + GetModulus> + + ShoupMatrixFMA, + LweModOp: ArithmeticOps + + VectorOps + + GetModulus>, + M::R: TryConvertFrom1<[i32], CiphertextModulus> + RowEntity + Debug, + ::R: RowMut, +{ + pub(super) fn new(parameters: BoolParameters) -> Self + where + RlweModOp: ModInit>, + LweModOp: ModInit>, + NttOp: NttInit>, + { + //TODO(Jay): Run sanity checks for modulus values in parameters + + // generates dlog map s.t. (+/-)g^{k} % q = a, for all a \in Z*_{q} and k \in + // [0, q/4). We store the dlog `k` at index `a`. This makes it easier to + // simply look up `k` at runtime as vec[a]. If a = g^{k} then dlog is + // stored as k. If a = -g^{k} then dlog is stored as k = q/4. This is done to + // differentiate sign. + let g = parameters.g(); + let q = *parameters.br_q(); + let mut g_k_dlog_map = vec![0usize; q]; + for i in 0..q / 4 { + let v = mod_exponent(g as u64, i as u64, q as u64) as usize; + // g^i + g_k_dlog_map[v] = i; + // -(g^i) + g_k_dlog_map[q - v] = i + (q / 4); + } + + let embedding_factor = (2 * parameters.rlwe_n().0) / q; + + let rlwe_nttop = NttOp::new(parameters.rlwe_q(), parameters.rlwe_n().0); + let rlwe_modop = RlweModOp::new(*parameters.rlwe_q()); + let lwe_modop = LweModOp::new(*parameters.lwe_q()); + + let q = *parameters.br_q(); + let qby2 = q >> 1; + let qby8 = q >> 3; + // Q/8 (Q: rlwe_q) + let true_m_el = parameters.rlwe_q().true_el(); + // -Q/8 + let false_m_el = parameters.rlwe_q().false_el(); + let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); + + let init_test_vec = |partition_el: usize, + before_partition_el: M::MatElement, + after_partition_el: M::MatElement| { + let mut test_vec = M::R::zeros(qby2); + for i in 0..qby2 { + if i < partition_el { + test_vec.as_mut()[i] = before_partition_el; + } else { + test_vec.as_mut()[i] = after_partition_el; + } + } + + // v(X) -> v(X^{-g}) + let mut test_vec_autog = M::R::zeros(qby2); + izip!( + test_vec.as_ref().iter(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(v, to_index, to_sign)| { + if !to_sign { + // negate + test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v); + } else { + test_vec_autog.as_mut()[*to_index] = *v; + } + }); + + return test_vec_autog; + }; + + let nand_test_vec = init_test_vec(3 * qby8, true_m_el, false_m_el); + let and_test_vec = init_test_vec(3 * qby8, false_m_el, true_m_el); + let or_test_vec = init_test_vec(qby8, false_m_el, true_m_el); + let nor_test_vec = init_test_vec(qby8, true_m_el, false_m_el); + let xor_test_vec = init_test_vec(qby8, false_m_el, true_m_el); + let xnor_test_vec = init_test_vec(qby8, true_m_el, false_m_el); + + // auto map indices and sign + // Auto maps are stored as [-g, g^{1}, g^{2}, ..., g^{w}] + let mut rlwe_auto_maps = vec![]; + let ring_size = parameters.rlwe_n().0; + let g = parameters.g(); + let br_q = parameters.br_q(); + let auto_element_dlogs = parameters.auto_element_dlogs(); + assert!(auto_element_dlogs[0] == 0); + for i in auto_element_dlogs.into_iter() { + let el = if i == 0 { + -(g as isize) + } else { + (g.pow(i as u32) % br_q) as isize + }; + rlwe_auto_maps.push(generate_auto_map(ring_size, el)) + } + + let rlwe_qby4 = parameters.rlwe_q().qby4(); + + let scratch_memory = ScratchMemory::new(¶meters); + + let ni_ui_to_s_ks_decomposer = if parameters.variant() + == &ParameterVariant::NonInteractiveMultiParty + { + Some(parameters + .non_interactive_ui_to_s_key_switch_decomposer::>()) + } else { + None + }; + + let pbs_info = BoolPbsInfo { + auto_decomposer: parameters.auto_decomposer(), + lwe_decomposer: parameters.lwe_decomposer(), + rlwe_rgsw_decomposer: parameters.rlwe_rgsw_decomposer(), + g_k_dlog_map, + embedding_factor, + lwe_modop, + rlwe_modop, + rlwe_nttop, + rlwe_qby4, + rlwe_auto_maps, + parameters: parameters, + }; + + BoolEvaluator { + pbs_info, + scratch_memory, + nand_test_vec, + and_test_vec, + or_test_vec, + nor_test_vec, + xnor_test_vec, + xor_test_vec, + ni_ui_to_s_ks_decomposer, + _phantom: PhantomData, + } + } + + pub(crate) fn client_key( + &self, + ) -> ClientKey<::Seed, M::MatElement> { + ClientKey::new(self.parameters().clone()) + } + + pub(super) fn single_party_server_key>( + &self, + client_key: &K, + ) -> SeededSinglePartyServerKey, [u8; 32]> { + assert_eq!(self.parameters().variant(), &ParameterVariant::SingleParty); + + DefaultSecureRng::with_local_mut(|rng| { + let mut main_seed = [0u8; 32]; + rng.fill_bytes(&mut main_seed); + + let mut main_prng = DefaultSecureRng::new_seeded(main_seed); + + let rlwe_n = self.pbs_info.parameters.rlwe_n().0; + let sk_rlwe = client_key.sk_rlwe(); + let sk_lwe = client_key.sk_lwe(); + + // generate auto keys + let mut auto_keys = HashMap::new(); + let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); + let g = self.pbs_info.parameters.g(); + let br_q = self.pbs_info.parameters.br_q(); + let auto_els = self.pbs_info.parameters.auto_element_dlogs(); + for i in auto_els.into_iter() { + let g_pow = if i == 0 { + -(g as isize) + } else { + (g.pow(i as u32) % br_q) as isize + }; + let mut gk = M::zeros( + self.pbs_info.auto_decomposer.decomposition_count().0, + rlwe_n, + ); + seeded_auto_key_gen( + &mut gk, + &sk_rlwe, + g_pow, + &auto_gadget, + &self.pbs_info.rlwe_modop, + &self.pbs_info.rlwe_nttop, + &mut main_prng, + rng, + ); + auto_keys.insert(i, gk); + } + + // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let rlwe_q = self.pbs_info.parameters.rlwe_q(); + let (rlrg_d_a, rlrg_d_b) = ( + self.pbs_info.rlwe_rgsw_decomposer.0.decomposition_count().0, + self.pbs_info.rlwe_rgsw_decomposer.1.decomposition_count().0, + ); + let rlrg_gadget_a = self.pbs_info.rlwe_rgsw_decomposer.0.gadget_vector(); + let rlrg_gadget_b = self.pbs_info.rlwe_rgsw_decomposer.1.gadget_vector(); + let rgsw_cts = sk_lwe + .iter() + .map(|si| { + // X^{si}; assume |emebedding_factor * si| < N + let mut m = M::R::zeros(ring_size); + let si = (self.pbs_info.embedding_factor as i32) * si; + // dbg!(si); + if si < 0 { + // X^{-i} = X^{2N - i} = -X^{N-i} + m.as_mut()[ring_size - (si.abs() as usize)] = rlwe_q.neg_one(); + } else { + // X^{i} + m.as_mut()[si.abs() as usize] = M::MatElement::one(); + } + + let mut rgsw_si = M::zeros(rlrg_d_a * 2 + rlrg_d_b, ring_size); + secret_key_encrypt_rgsw( + &mut rgsw_si, + m.as_ref(), + &rlrg_gadget_a, + &rlrg_gadget_b, + &sk_rlwe, + &self.pbs_info.rlwe_modop, + &self.pbs_info.rlwe_nttop, + &mut main_prng, + rng, + ); + + rgsw_si + }) + .collect_vec(); + + // LWE KSK from RLWE secret s -> LWE secret z + let d_lwe_gadget = self.pbs_info.lwe_decomposer.gadget_vector(); + let lwe_ksk = seeded_lwe_ksk_keygen( + &sk_rlwe, + &sk_lwe, + &d_lwe_gadget, + &self.pbs_info.lwe_modop, + &mut main_prng, + rng, + ); + + SeededSinglePartyServerKey::from_raw( + auto_keys, + rgsw_cts, + lwe_ksk, + self.pbs_info.parameters.clone(), + main_seed, + ) + }) + } + + pub(super) fn gen_interactive_multi_party_server_key_share< + K: InteractiveMultiPartyClientKey, + >( + &self, + user_id: usize, + total_users: usize, + cr_seed: &InteractiveMultiPartyCrs<[u8; 32]>, + collective_pk: &M, + client_key: &K, + ) -> CommonReferenceSeededInteractiveMultiPartyServerKeyShare< + M, + BoolParameters, + InteractiveMultiPartyCrs<[u8; 32]>, + > { + assert_eq!( + self.parameters().variant(), + &ParameterVariant::InteractiveMultiParty + ); + assert!(user_id < total_users); + + let sk_rlwe = client_key.sk_rlwe(); + let sk_lwe = client_key.sk_lwe(); + + let g = self.pbs_info.parameters.g(); + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let rlwe_q = self.pbs_info.parameters.rlwe_q(); + let lwe_q = self.pbs_info.parameters.lwe_q(); + + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + // sanity check + assert!(sk_rlwe.len() == ring_size); + assert!(sk_lwe.len() == self.pbs_info.parameters.lwe_n().0); + + // auto keys + let auto_keys = self._common_rountine_multi_party_auto_keys_share_gen( + cr_seed.auto_keys_cts_seed::(), + &sk_rlwe, + ); + + // rgsw ciphertexts of lwe secret elements + let (self_leader_rgsws, not_self_leader_rgsws) = DefaultSecureRng::with_local_mut(|rng| { + let mut self_leader_rgsw = vec![]; + let mut not_self_leader_rgsws = vec![]; + + let (segment_start, segment_end) = + multi_party_user_id_lwe_segment(user_id, total_users, self.pbs_info().lwe_n()); + + // self LWE secret indices + { + // LWE secret indices for which user is the leader they need to send RGSW(m) for + // RLWE x RGSW multiplication + let rlrg_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); + let (rlrg_d_a, rlrg_d_b) = ( + rlrg_decomposer.a().decomposition_count(), + rlrg_decomposer.b().decomposition_count(), + ); + let (gadget_a, gadget_b) = ( + rlrg_decomposer.a().gadget_vector(), + rlrg_decomposer.b().gadget_vector(), + ); + for s_index in segment_start..segment_end { + let mut out_rgsw = M::zeros(rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2, ring_size); + public_key_encrypt_rgsw( + &mut out_rgsw, + &encode_x_pow_si_with_emebedding_factor::< + M::R, + CiphertextModulus, + >( + sk_lwe[s_index], + self.pbs_info().embedding_factor(), + ring_size, + self.pbs_info().rlwe_q(), + ) + .as_ref(), + collective_pk, + &gadget_a, + &gadget_b, + rlweq_modop, + rlweq_nttop, + rng, + ); + self_leader_rgsw.push(out_rgsw); + } + } + + // not self LWE secret indices + { + // LWE secret indices for which user isn't the leader, they need to send RGSW(m) + // for RGSW x RGSW multiplcation + let rgsw_rgsw_decomposer = self + .pbs_info + .parameters + .rgsw_rgsw_decomposer::>(); + let (rgrg_d_a, rgrg_d_b) = ( + rgsw_rgsw_decomposer.a().decomposition_count(), + rgsw_rgsw_decomposer.b().decomposition_count(), + ); + let (rgrg_gadget_a, rgrg_gadget_b) = ( + rgsw_rgsw_decomposer.a().gadget_vector(), + rgsw_rgsw_decomposer.b().gadget_vector(), + ); + + for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) { + let mut out_rgsw = M::zeros(rgrg_d_a.0 * 2 + rgrg_d_b.0 * 2, ring_size); + public_key_encrypt_rgsw( + &mut out_rgsw, + &encode_x_pow_si_with_emebedding_factor::< + M::R, + CiphertextModulus, + >( + sk_lwe[s_index], + self.pbs_info().embedding_factor(), + ring_size, + self.pbs_info().rlwe_q(), + ) + .as_ref(), + collective_pk, + &rgrg_gadget_a, + &rgrg_gadget_b, + rlweq_modop, + rlweq_nttop, + rng, + ); + + not_self_leader_rgsws.push(out_rgsw); + } + } + + (self_leader_rgsw, not_self_leader_rgsws) + }); + + // LWE Ksk + let lwe_ksk = self._common_rountine_multi_party_lwe_ksk_share_gen( + cr_seed.lwe_ksk_cts_seed_seed::(), + &sk_rlwe, + &sk_lwe, + ); + + CommonReferenceSeededInteractiveMultiPartyServerKeyShare::new( + self_leader_rgsws, + not_self_leader_rgsws, + auto_keys, + lwe_ksk, + cr_seed.clone(), + self.pbs_info.parameters.clone(), + user_id, + ) + } + + pub(super) fn aggregate_interactive_multi_party_server_key_shares( + &self, + shares: &[CommonReferenceSeededInteractiveMultiPartyServerKeyShare< + M, + BoolParameters, + InteractiveMultiPartyCrs, + >], + ) -> SeededInteractiveMultiPartyServerKey< + M, + InteractiveMultiPartyCrs, + BoolParameters, + > + where + S: PartialEq + Clone, + M: Clone, + { + assert_eq!( + self.parameters().variant(), + &ParameterVariant::InteractiveMultiParty + ); + assert!(shares.len() > 0); + + let total_users = shares.len(); + + let parameters = shares[0].parameters().clone(); + let cr_seed = shares[0].cr_seed(); + + let rlwe_n = parameters.rlwe_n().0; + + // sanity checks + shares.iter().skip(1).for_each(|s| { + assert!(s.parameters() == ¶meters); + assert!(s.cr_seed() == cr_seed); + }); + + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + // auto keys + let mut auto_keys = HashMap::new(); + let auto_elements_dlog = parameters.auto_element_dlogs(); + for i in auto_elements_dlog.into_iter() { + let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); + + shares.iter().for_each(|s| { + let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing"); + assert!( + auto_key_share_i.dimension() + == (parameters.auto_decomposition_count().0, rlwe_n) + ); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, + ); + }); + + auto_keys.insert(i, key); + } + + // rgsw ciphertext (most expensive part!) + let rgsw_cts = { + let rgsw_x_rgsw_decomposer = + parameters.rgsw_rgsw_decomposer::>(); + let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); + let rgsw_x_rgsw_dimension = ( + rgsw_x_rgsw_decomposer.a().decomposition_count().0 * 2 + + rgsw_x_rgsw_decomposer.b().decomposition_count().0 * 2, + rlwe_n, + ); + let rlwe_x_rgsw_dimension = ( + rlwe_x_rgsw_decomposer.a().decomposition_count().0 * 2 + + rlwe_x_rgsw_decomposer.b().decomposition_count().0 * 2, + rlwe_n, + ); + + let mut rgsw_x_rgsw_scratch = M::zeros( + rgsw_x_rgsw_scratch_rows(rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer), + rlwe_n, + ); + + let shares_in_correct_order = (0..total_users) + .map(|i| shares.iter().find(|s| s.user_id() == i).unwrap()) + .collect_vec(); + + let lwe_n = self.parameters().lwe_n().0; + let (users_segments, users_segments_sizes): (Vec<(usize, usize)>, Vec) = (0 + ..total_users) + .map(|(user_id)| { + let (start_index, end_index) = + multi_party_user_id_lwe_segment(user_id, total_users, lwe_n); + ((start_index, end_index), end_index - start_index) + }) + .unzip(); + + let mut rgsw_cts = Vec::with_capacity(lwe_n); + users_segments + .iter() + .enumerate() + .for_each(|(user_id, user_segment)| { + let share = shares_in_correct_order[user_id]; + for secret_index in user_segment.0..user_segment.1 { + let mut rgsw_i = + share.self_leader_rgsws()[secret_index - user_segment.0].clone(); + // assert already exists in RGSW x RGSW rountine + assert!(rgsw_i.dimension() == rlwe_x_rgsw_dimension); + + // multiply leader's RGSW ct at `secret_index` with RGSW cts of other users + // for lwe index `secret_index` + (0..total_users) + .filter(|i| i != &user_id) + .for_each(|other_user_id| { + let mut offset = 0; + if other_user_id < user_id { + offset = users_segments_sizes[other_user_id]; + } + + let mut other_rgsw_i = shares_in_correct_order[other_user_id] + .not_self_leader_rgsws() + [secret_index.checked_sub(offset).unwrap()] + .clone(); + // assert already exists in RGSW x RGSW rountine + assert!(other_rgsw_i.dimension() == rgsw_x_rgsw_dimension); + + // send to evaluation domain for RGSwxRGSW mul + other_rgsw_i + .iter_rows_mut() + .for_each(|r| rlweq_nttop.forward(r.as_mut())); + + rgsw_by_rgsw_inplace( + &mut RgswCiphertextMutRef::new( + rgsw_i.as_mut(), + rlwe_x_rgsw_decomposer.a().decomposition_count().0, + rlwe_x_rgsw_decomposer.b().decomposition_count().0, + ), + &RgswCiphertextRef::new( + other_rgsw_i.as_ref(), + rgsw_x_rgsw_decomposer.a().decomposition_count().0, + rgsw_x_rgsw_decomposer.b().decomposition_count().0, + ), + rlwe_x_rgsw_decomposer, + &rgsw_x_rgsw_decomposer, + &mut RuntimeScratchMutRef::new(rgsw_x_rgsw_scratch.as_mut()), + rlweq_nttop, + rlweq_modop, + ); + }); + + rgsw_cts.push(rgsw_i); + } + }); + + rgsw_cts + }; + + // LWE ksks + let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); + let lweq_modop = &self.pbs_info.lwe_modop; + shares.iter().for_each(|si| { + assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); + lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref()) + }); + + SeededInteractiveMultiPartyServerKey::new( + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed.clone(), + parameters, + ) + } + + pub(super) fn aggregate_non_interactive_multi_party_server_key_shares( + &self, + cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, + key_shares: &[CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare< + M, + BoolParameters, + NonInteractiveMultiPartyCrs<[u8; 32]>, + >], + ) -> SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs<[u8; 32]>, + BoolParameters, + > + where + M: Clone + Debug, + ::R: RowMut + Clone, + { + assert_eq!( + self.parameters().variant(), + &ParameterVariant::NonInteractiveMultiParty + ); + + let total_users = key_shares.len(); + let key_shares = (0..total_users) + .map(|user_id| { + // find share of user_id + key_shares + .iter() + .find(|share| share.user_index() == user_id) + .expect(&format!("Key Share for user_id={user_id} missing")) + }) + .collect_vec(); + + // check parameters and cr seed are equal + { + key_shares.iter().for_each(|k| { + assert!(k.parameters() == self.parameters()); + assert!(k.cr_seed() == cr_seed); + }); + } + + let rlwe_modop = &self.pbs_info().rlwe_modop; + let nttop = &self.pbs_info().rlwe_nttop; + let ring_size = self.parameters().rlwe_n().0; + let rlwe_q = self.parameters().rlwe_q(); + let lwe_modop = self.pbs_info().modop_lweq(); + + // Generate Key switching key from u_j to s, where u_j user j's RLWE secret and + // s is the ideal RLWE secret. + // + // User j gives [s_j * a_{i, j} + e + \beta^i u_j] where a_{i, j} is user j + // specific publicly know polynomial sampled from user j's pseudo random seed + // defined in the protocol. + // + // User k, k != j, gives [s_k * a_{i, j} + e]. + // + // We set Ksk(u_j -> s) = [s_j * a_{i, j} + e + \beta^i u_j + \sum_{k \in P, k + // != j} s_k * a_{i, j} + e] + let ni_uj_to_s_decomposer = self + .parameters() + .non_interactive_ui_to_s_key_switch_decomposer::>(); + let mut uj_to_s_ksks = key_shares + .iter() + .map(|share| { + let mut useri_ui_to_s_ksk = share.ui_to_s_ksk().clone(); + assert!( + useri_ui_to_s_ksk.dimension() + == (ni_uj_to_s_decomposer.decomposition_count().0, ring_size) + ); + key_shares + .iter() + .filter(|x| x.user_index() != share.user_index()) + .for_each(|other_share| { + let op2 = other_share.ui_to_s_ksk_zero_encs_for_user_i(share.user_index()); + assert!( + op2.dimension() + == (ni_uj_to_s_decomposer.decomposition_count().0, ring_size) + ); + izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each( + |(add_to, add_from)| { + rlwe_modop.elwise_add_mut(add_to.as_mut(), add_from.as_ref()) + }, + ); + }); + useri_ui_to_s_ksk + }) + .collect_vec(); + + let rgsw_cts = { + // Send u_j -> s ksk in evaluation domain and sample corresponding a's using + // user j's ksk seed to prepare for upcoming key switches + uj_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())) + }); + let uj_to_s_ksks_part_a_eval = key_shares + .iter() + .map(|share| { + let mut ksk_prng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(share.user_index()), + ); + let mut ais = + M::zeros(ni_uj_to_s_decomposer.decomposition_count().0, ring_size); + + ais.iter_rows_mut().for_each(|r_ai| { + RandomFillUniformInModulus::random_fill( + &mut ksk_prng, + rlwe_q, + r_ai.as_mut(), + ); + + nttop.forward(r_ai.as_mut()) + }); + ais + }) + .collect_vec(); + + let rgsw_x_rgsw_decomposer = self + .parameters() + .rgsw_rgsw_decomposer::>(); + let rlwe_x_rgsw_decomposer = self + .parameters() + .rlwe_rgsw_decomposer::>(); + + let d_max = if rgsw_x_rgsw_decomposer.a().decomposition_count().0 + > rgsw_x_rgsw_decomposer.b().decomposition_count().0 + { + rgsw_x_rgsw_decomposer.a().decomposition_count().0 + } else { + rgsw_x_rgsw_decomposer.b().decomposition_count().0 + }; + + let mut scratch_rgsw_x_rgsw = M::zeros( + rgsw_x_rgsw_scratch_rows(&rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer), + self.parameters().rlwe_n().0, + ); + + // Recall that given u_j * a_{i, l} + e + \beta^i X^{s_j[l]} from user j + // + // We generate: + // + // - RLWE(-s \beta^i X^{s_j[l]}) = KS_{u_j -> s}(a_{i, l} * s + e) + (0 , u_j * + // a_{i, l} + e + \beta^i X^{s_j[l]}), where KS_{u_j -> s}(a_{i, l} * s + e) = + // RLWE_s(a_{i, l} * u_j) + // - RLWE(\beta^i X^{s_j[l]}) = KS_{u_j -> s}(-a_{i,l}) + (u_j * a_{i, l} + e + + // \beta^i X^{s_j[l]}, 0), where KS_{u_j -> s}(-a_{i,l}) = RLWE_s(-a_{i,l} * + // u_j) + // + // a_{i, l} * s + e = \sum_{j \in P} a_{i, l} * s_{j} + e + let user_segments = (0..total_users) + .map(|user_id| { + multi_party_user_id_lwe_segment( + user_id, + total_users, + self.parameters().lwe_n().0, + ) + }) + .collect_vec(); + // Note: Each user is assigned a contigous LWE segement and the LWE dimension is + // split approximately uniformly across all users. Hence, concatenation of all + // user specific lwe segments will give LWE dimension. + let rgsw_cts = user_segments + .into_iter() + .enumerate() + .flat_map(|(user_id, lwe_segment)| { + (lwe_segment.0..lwe_segment.1) + .into_iter() + .map(|lwe_index| { + // We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before + // we sampling we need to puncture a_prng d_max - d_b times to align + // a_i's. After sampling we decompose `-a_i`s and send them to + // evaluation domain for upcoming key switches. + let mut a_prng = DefaultSecureRng::new_seeded( + cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), + ); + + let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); + (0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0) + .for_each(|_| { + RandomFillUniformInModulus::random_fill( + &mut a_prng, + rlwe_q, + scratch.as_mut(), + ); + }); + + let decomp_neg_ais = (0..rgsw_x_rgsw_decomposer + .b() + .decomposition_count() + .0) + .map(|_| { + RandomFillUniformInModulus::random_fill( + &mut a_prng, + rlwe_q, + scratch.as_mut(), + ); + rlwe_modop.elwise_neg_mut(scratch.as_mut()); + + let mut decomp_neg_ai = M::zeros( + ni_uj_to_s_decomposer.decomposition_count().0, + self.parameters().rlwe_n().0, + ); + scratch.as_ref().iter().enumerate().for_each(|(index, el)| { + ni_uj_to_s_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (decomp_neg_ai.as_mut()[row_j]).as_mut()[index] = + d_el; + }); + }); + + decomp_neg_ai + .iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + decomp_neg_ai + }) + .collect_vec(); + + // Aggregate zero encryptions to produce a_{i, l} * s + e = + // \sum_{k in P} a_{i, l} * s_{k} + e where s is the ideal RLWE + // secret. Aggregated a_{i, l} * s + e are key switched using + // Ksk(u_j -> s) to produce RLWE_{s}(a_{i, l} * s) which are + // then use to produce RLWE'(-sX^{s_{lwe}[l]}). + // Hence, after aggregation we decompose a_{i, l} * s + e to + // prepare for key switching + let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer + .a() + .decomposition_count() + .0) + .map(|i| { + let mut sum = M::R::zeros(self.parameters().rlwe_n().0); + key_shares.iter().for_each(|k| { + let to_add_ref = k + .ni_rgsw_zero_enc_for_lwe_index(lwe_index) + .get_row_slice(i); + assert!(to_add_ref.len() == self.parameters().rlwe_n().0); + rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref); + }); + + // decompose + let mut decomp_sum = M::zeros( + ni_uj_to_s_decomposer.decomposition_count().0, + self.parameters().rlwe_n().0, + ); + sum.as_ref().iter().enumerate().for_each(|(index, el)| { + ni_uj_to_s_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el; + }); + }); + + decomp_sum + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())); + + decomp_sum + }) + .collect_vec(); + + // Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the + // leader, ie user's id = user_id. + // Recall leader's RGSW ciphertext must be constructed + // for RLWE x RGSW multiplication, and is then used + // to accumulate, using RGSW x RGSW multiplication, + // X^{s_{j != user_id, lwe}[l]} from other users. + let mut rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw( + key_shares[user_id] + .ni_rgsw_cts_for_self_leader_lwe_index(lwe_index), + &ni_rgsw_zero_encs[rgsw_x_rgsw_decomposer + .a() + .decomposition_count() + .0 + - rlwe_x_rgsw_decomposer.a().decomposition_count().0..], + &decomp_neg_ais[rgsw_x_rgsw_decomposer + .b() + .decomposition_count() + .0 + - rlwe_x_rgsw_decomposer.b().decomposition_count().0..], + &rlwe_x_rgsw_decomposer, + self.parameters(), + (&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]), + rlwe_modop, + nttop, + false, + ); + + // RGSW for lwe_index of users that are not leader. + // + // Recall that for users that are not leader for the + // lwe_index we require to produce RGSW ciphertext for + // RGSW x RGSW product + (0..total_users) + .filter(|i| *i != user_id) + .for_each(|other_user_id| { + let mut other_rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw( + key_shares[other_user_id] + .ni_rgsw_cts_for_self_not_leader_lwe_index(lwe_index), + &ni_rgsw_zero_encs, + &decomp_neg_ais, + &rgsw_x_rgsw_decomposer, + self.parameters(), + ( + &uj_to_s_ksks[other_user_id], + &uj_to_s_ksks_part_a_eval[other_user_id], + ), + rlwe_modop, + nttop, + true, + ); + + rgsw_by_rgsw_inplace( + &mut RgswCiphertextMutRef::new( + rgsw_i.as_mut(), + rlwe_x_rgsw_decomposer.a().decomposition_count().0, + rlwe_x_rgsw_decomposer.b().decomposition_count().0, + ), + &RgswCiphertextRef::new( + other_rgsw_i.as_ref(), + rgsw_x_rgsw_decomposer.a().decomposition_count().0, + rgsw_x_rgsw_decomposer.b().decomposition_count().0, + ), + &rlwe_x_rgsw_decomposer, + &rgsw_x_rgsw_decomposer, + &mut RuntimeScratchMutRef::new( + scratch_rgsw_x_rgsw.as_mut(), + ), + nttop, + rlwe_modop, + ) + }); + + rgsw_i + }) + .collect_vec() + }) + .collect_vec(); + + // put u_j to s ksk in coefficient domain + uj_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.backward(r.as_mut())) + }); + + rgsw_cts + }; + + // auto keys + let auto_keys = { + let mut auto_keys = HashMap::new(); + let auto_elements_dlog = self.parameters().auto_element_dlogs(); + for i in auto_elements_dlog.into_iter() { + let mut key = M::zeros(self.parameters().auto_decomposition_count().0, ring_size); + + key_shares.iter().for_each(|s| { + let auto_key_share_i = + s.auto_keys_share().get(&i).expect("Auto key {i} missing"); + assert!( + auto_key_share_i.dimension() + == (self.parameters().auto_decomposition_count().0, ring_size) + ); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlwe_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, + ); + }); + + auto_keys.insert(i, key); + } + auto_keys + }; + + // LWE ksk + let lwe_ksk = { + let mut lwe_ksk = + M::R::zeros(self.parameters().lwe_decomposition_count().0 * ring_size); + key_shares.iter().for_each(|s| { + assert!( + s.lwe_ksk_share().as_ref().len() + == self.parameters().lwe_decomposition_count().0 * ring_size + ); + lwe_modop.elwise_add_mut(lwe_ksk.as_mut(), s.lwe_ksk_share().as_ref()); + }); + lwe_ksk + }; + + SeededNonInteractiveMultiPartyServerKey::new( + uj_to_s_ksks, + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed.clone(), + self.parameters().clone(), + ) + } + + pub(super) fn gen_non_interactive_multi_party_key_share< + K: NonInteractiveMultiPartyClientKey, + >( + &self, + cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, + self_index: usize, + total_users: usize, + client_key: &K, + ) -> CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare< + M, + BoolParameters, + NonInteractiveMultiPartyCrs<[u8; 32]>, + > { + assert_eq!( + self.parameters().variant(), + &ParameterVariant::NonInteractiveMultiParty + ); + + // TODO: check whether parameters support `total_users` + let nttop = self.pbs_info().nttop_rlweq(); + let rlwe_modop = self.pbs_info().modop_rlweq(); + // let ring_size = self.pbs_info().rlwe_n(); + let rlwe_q = self.parameters().rlwe_q(); + + let sk_rlwe = client_key.sk_rlwe(); + let sk_u_rlwe = client_key.sk_u_rlwe(); + let sk_lwe = client_key.sk_lwe(); + + let (ui_to_s_ksk, ksk_zero_encs_for_others) = DefaultSecureRng::with_local_mut(|rng| { + // ui_to_s_ksk + let non_interactive_decomposer = self + .parameters() + .non_interactive_ui_to_s_key_switch_decomposer::>( + ); + let non_interactive_gadget_vec = non_interactive_decomposer.gadget_vector(); + let ui_to_s_ksk = { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(self_index), + ); + + non_interactive_ksk_gen::( + &sk_rlwe, + &sk_u_rlwe, + &non_interactive_gadget_vec, + &mut p_rng, + rng, + nttop, + rlwe_modop, + ) + }; + + // zero encryptions for others uj_to_s ksk + let all_users_except_self = (0..total_users).filter(|x| *x != self_index); + let zero_encs_for_others = all_users_except_self + .map(|other_user_index| { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(other_user_index), + ); + let zero_encs = + non_interactive_ksk_zero_encryptions_for_other_party_i::( + &sk_rlwe, + &non_interactive_gadget_vec, + &mut p_rng, + rng, + nttop, + rlwe_modop, + ); + zero_encs + }) + .collect_vec(); + + (ui_to_s_ksk, zero_encs_for_others) + }); + + // Non-interactive RGSW cts + let (ni_rgsw_zero_encs, self_leader_ni_rgsw_cts, not_self_leader_rgsw_cts) = { + let rgsw_x_rgsw_decomposer = self + .parameters() + .rgsw_rgsw_decomposer::>(); + let rlwe_x_rgsw_decomposer = self + .parameters() + .rlwe_rgsw_decomposer::>(); + + // We assume that d_{a/b} for RGSW x RGSW are always < d'_{a/b} for RLWE x RGSW + assert!( + rlwe_x_rgsw_decomposer.a().decomposition_count().0 + < rgsw_x_rgsw_decomposer.a().decomposition_count().0 + ); + assert!( + rlwe_x_rgsw_decomposer.b().decomposition_count().0 + < rgsw_x_rgsw_decomposer.b().decomposition_count().0 + ); + + let sj_poly_eval = { + let mut s = M::R::try_convert_from(&sk_rlwe, rlwe_q); + nttop.forward(s.as_mut()); + s + }; + + let d_rgsw_a = rgsw_x_rgsw_decomposer.a().decomposition_count().0; + let d_rgsw_b = rgsw_x_rgsw_decomposer.b().decomposition_count().0; + let d_max = std::cmp::max(d_rgsw_a, d_rgsw_b); + + // Zero encyptions for each LWE index. We generate d_a zero encryptions for each + // LWE index using a_{i, l} with i \in {d_max - d_a , d_max) and l = lwe_index + let zero_encs = { + (0..self.parameters().lwe_n().0) + .map(|lwe_index| { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), + ); + + let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); + + // puncture seeded prng d_max - d_a times + (0..(d_max - d_rgsw_a)).into_iter().for_each(|_| { + RandomFillUniformInModulus::random_fill( + &mut p_rng, + rlwe_q, + scratch.as_mut(), + ); + }); + + let mut zero_enc = M::zeros(d_rgsw_a, self.parameters().rlwe_n().0); + zero_enc.iter_rows_mut().for_each(|out| { + // sample a_i + RandomFillUniformInModulus::random_fill( + &mut p_rng, + rlwe_q, + out.as_mut(), + ); + + // a_i * s_j + nttop.forward(out.as_mut()); + rlwe_modop.elwise_mul_mut(out.as_mut(), sj_poly_eval.as_ref()); + nttop.backward(out.as_mut()); + + // a_j * s_j + e + DefaultSecureRng::with_local_mut_mut(&mut |rng| { + RandomFillGaussianInModulus::random_fill( + rng, + rlwe_q, + scratch.as_mut(), + ); + }); + + rlwe_modop.elwise_add_mut(out.as_mut(), scratch.as_ref()); + }); + + zero_enc + }) + .collect_vec() + }; + + let uj_poly_eval = { + let mut u = M::R::try_convert_from(&sk_u_rlwe, rlwe_q); + nttop.forward(u.as_mut()); + u + }; + + // Generate non-interactive RGSW ciphertexts a_{i, l} u_j + e + \beta X^{s_j[l]} + // for i \in (0, d_max] + let (self_start_index, self_end_index) = multi_party_user_id_lwe_segment( + self_index, + total_users, + self.parameters().lwe_n().0, + ); + + // For LWE indices [self_start_index, self_end_index) user generates + // non-interactive RGSW cts for RLWE x RGSW product. We refer to + // such indices as where user is the leader. For the rest of + // the indices user generates non-interactive RGWS cts for RGSW x + // RGSW multiplication. We refer to such indices as where user is + // not the leader. + let self_leader_ni_rgsw_cts = { + let max_rlwe_x_rgsw_decomposer = + if rlwe_x_rgsw_decomposer.a().decomposition_count().0 + > rlwe_x_rgsw_decomposer.b().decomposition_count().0 + { + rlwe_x_rgsw_decomposer.a() + } else { + rlwe_x_rgsw_decomposer.b() + }; + + let gadget_vec = max_rlwe_x_rgsw_decomposer.gadget_vector(); + + (self_start_index..self_end_index) + .map(|lwe_index| { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), + ); + + // puncture p_rng d_max - d'_max time to align with `a_{i, l}`s used to + // produce RGSW cts for RGSW x RGSW + let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); + (0..(d_max - max_rlwe_x_rgsw_decomposer.decomposition_count().0)) + .into_iter() + .for_each(|_| { + RandomFillUniformInModulus::random_fill( + &mut p_rng, + rlwe_q, + scratch.as_mut(), + ); + }); + + let mut ni_rgsw_cts = M::zeros( + max_rlwe_x_rgsw_decomposer.decomposition_count().0, + self.parameters().rlwe_n().0, + ); + + // X^{s_{j, lwe}[l]} + let m_poly = encode_x_pow_si_with_emebedding_factor::( + sk_lwe[lwe_index], + self.pbs_info().embedding_factor(), + self.parameters().rlwe_n().0, + rlwe_q, + ); + + izip!(ni_rgsw_cts.iter_rows_mut(), gadget_vec.iter()).for_each( + |(out, beta)| { + // sample a_i + RandomFillUniformInModulus::random_fill( + &mut p_rng, + rlwe_q, + out.as_mut(), + ); + + // u_j * a_i + nttop.forward(out.as_mut()); + rlwe_modop.elwise_mul_mut(out.as_mut(), uj_poly_eval.as_ref()); + nttop.backward(out.as_mut()); + + // u_j + a_i + e + DefaultSecureRng::with_local_mut_mut(&mut |rng| { + RandomFillGaussianInModulus::random_fill( + rng, + rlwe_q, + scratch.as_mut(), + ); + }); + rlwe_modop.elwise_add_mut(out.as_mut(), scratch.as_ref()); + + // u_j + a_i + e + beta m + rlwe_modop.elwise_scalar_mul( + scratch.as_mut(), + m_poly.as_ref(), + beta, + ); + rlwe_modop.elwise_add_mut(out.as_mut(), scratch.as_ref()); + }, + ); + + ni_rgsw_cts + }) + .collect_vec() + }; + + let not_self_leader_rgsw_cts = { + let max_rgsw_x_rgsw_decomposer = + if rgsw_x_rgsw_decomposer.a().decomposition_count().0 + > rgsw_x_rgsw_decomposer.b().decomposition_count().0 + { + rgsw_x_rgsw_decomposer.a() + } else { + rgsw_x_rgsw_decomposer.b() + }; + let gadget_vec = max_rgsw_x_rgsw_decomposer.gadget_vector(); + + ((0..self_start_index).chain(self_end_index..self.parameters().lwe_n().0)) + .map(|lwe_index| { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), + ); + let mut ni_rgsw_cts = M::zeros( + max_rgsw_x_rgsw_decomposer.decomposition_count().0, + self.parameters().rlwe_n().0, + ); + let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); + + // X^{s_{j, lwe}[l]} + let m_poly = encode_x_pow_si_with_emebedding_factor::( + sk_lwe[lwe_index], + self.pbs_info().embedding_factor(), + self.parameters().rlwe_n().0, + rlwe_q, + ); + + izip!(ni_rgsw_cts.iter_rows_mut(), gadget_vec.iter()).for_each( + |(out, beta)| { + // sample a_i + RandomFillUniformInModulus::random_fill( + &mut p_rng, + rlwe_q, + out.as_mut(), + ); + + // u_j * a_i + nttop.forward(out.as_mut()); + rlwe_modop.elwise_mul_mut(out.as_mut(), uj_poly_eval.as_ref()); + nttop.backward(out.as_mut()); + + // u_j + a_i + e + DefaultSecureRng::with_local_mut_mut(&mut |rng| { + RandomFillGaussianInModulus::random_fill( + rng, + rlwe_q, + scratch.as_mut(), + ); + }); + rlwe_modop.elwise_add_mut(out.as_mut(), scratch.as_ref()); + + // u_j * a_i + e + beta m + rlwe_modop.elwise_scalar_mul( + scratch.as_mut(), + m_poly.as_ref(), + beta, + ); + rlwe_modop.elwise_add_mut(out.as_mut(), scratch.as_ref()); + }, + ); + + ni_rgsw_cts + }) + .collect_vec() + }; + + (zero_encs, self_leader_ni_rgsw_cts, not_self_leader_rgsw_cts) + }; + + // Auto key share + let auto_keys_share = { + let auto_seed = cr_seed.auto_keys_cts_seed::(); + self._common_rountine_multi_party_auto_keys_share_gen(auto_seed, &sk_rlwe) + }; + + // Lwe Ksk share + let lwe_ksk_share = { + let lwe_ksk_seed = cr_seed.lwe_ksk_cts_seed::(); + self._common_rountine_multi_party_lwe_ksk_share_gen(lwe_ksk_seed, &sk_rlwe, &sk_lwe) + }; + + CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare::new( + self_leader_ni_rgsw_cts, + not_self_leader_rgsw_cts, + ni_rgsw_zero_encs, + ui_to_s_ksk, + ksk_zero_encs_for_others, + auto_keys_share, + lwe_ksk_share, + self_index, + total_users, + self.parameters().lwe_n().0, + cr_seed.clone(), + self.parameters().clone(), + ) + } + + fn _common_rountine_multi_party_auto_keys_share_gen( + &self, + auto_seed: ::Seed, + sk_rlwe: &[i32], + ) -> HashMap { + let g = self.pbs_info.parameters.g(); + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let br_q = self.pbs_info.parameters.br_q(); + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + DefaultSecureRng::with_local_mut(|rng| { + let mut p_rng = DefaultSecureRng::new_seeded(auto_seed); + + let mut auto_keys = HashMap::new(); + let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); + let auto_element_dlogs = self.pbs_info.parameters.auto_element_dlogs(); + + for i in auto_element_dlogs.into_iter() { + let g_pow = if i == 0 { + -(g as isize) + } else { + (g.pow(i as u32) % br_q) as isize + }; + + let mut ksk_out = M::zeros( + self.pbs_info.auto_decomposer.decomposition_count().0, + ring_size, + ); + seeded_auto_key_gen( + &mut ksk_out, + sk_rlwe, + g_pow, + &auto_gadget, + rlweq_modop, + rlweq_nttop, + &mut p_rng, + rng, + ); + auto_keys.insert(i, ksk_out); + } + + auto_keys + }) + } + + fn _common_rountine_multi_party_lwe_ksk_share_gen( + &self, + lwe_ksk_seed: ::Seed, + sk_rlwe: &[i32], + sk_lwe: &[i32], + ) -> M::R { + DefaultSecureRng::with_local_mut(|rng| { + let mut p_rng = DefaultSecureRng::new_seeded(lwe_ksk_seed); + let lwe_modop = &self.pbs_info.lwe_modop; + let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector(); + seeded_lwe_ksk_keygen( + sk_rlwe, + sk_lwe, + &d_lwe_gadget_vec, + lwe_modop, + &mut p_rng, + rng, + ) + }) + } + + pub(super) fn multi_party_public_key_share>( + &self, + cr_seed: &InteractiveMultiPartyCrs<[u8; 32]>, + client_key: &K, + ) -> CommonReferenceSeededCollectivePublicKeyShare< + ::R, + [u8; 32], + BoolParameters<::MatElement>, + > { + DefaultSecureRng::with_local_mut(|rng| { + let mut share_out = M::R::zeros(self.pbs_info.parameters.rlwe_n().0); + let modop = &self.pbs_info.rlwe_modop; + let nttop = &self.pbs_info.rlwe_nttop; + let pk_seed = cr_seed.public_key_share_seed::(); + let mut main_prng = DefaultSecureRng::new_seeded(pk_seed); + public_key_share( + &mut share_out, + &client_key.sk_rlwe(), + modop, + nttop, + &mut main_prng, + rng, + ); + CommonReferenceSeededCollectivePublicKeyShare::new( + share_out, + pk_seed, + self.pbs_info.parameters.clone(), + ) + }) + } + + pub fn sk_encrypt>( + &self, + m: bool, + client_key: &K, + ) -> M::R { + //FIXME(Jay): Figure out a way to get Q/8 form modulus + let m = if m { + // Q/8 + self.pbs_info.rlwe_q().true_el() + } else { + // -Q/8 + self.pbs_info.rlwe_q().false_el() + }; + + DefaultSecureRng::with_local_mut(|rng| { + encrypt_lwe(&m, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop, rng) + }) + } + + pub fn sk_decrypt>( + &self, + lwe_ct: &M::R, + client_key: &K, + ) -> bool { + let m = decrypt_lwe(lwe_ct, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop); + self.pbs_info.rlwe_q().decode(m) + } +} + +impl BoolEvaluator +where + M: MatrixMut + MatrixEntity, + M::R: RowMut + RowEntity, + M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo, + RlweModOp: VectorOps + ArithmeticOps, + LweModOp: VectorOps + ArithmeticOps, + NttOp: Ntt, +{ + /// Returns c0 + c1 + Q/4 + fn _add_and_shift_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { + let modop = &self.pbs_info.rlwe_modop; + modop.elwise_add_mut(c0.as_mut(), c1.as_ref()); + // +Q/4 + c0.as_mut()[0] = modop.add(&c0.as_ref()[0], &self.pbs_info.rlwe_qby4); + } + + /// Returns 2(c0 - c1) + Q/4 + fn _subtract_double_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { + let modop = &self.pbs_info.rlwe_modop; + // c0 - c1 + modop.elwise_sub_mut(c0.as_mut(), c1.as_ref()); + + // double + c0.as_mut().iter_mut().for_each(|v| *v = modop.add(v, v)); + } +} + +impl BooleanGates + for BoolEvaluator +where + M: MatrixMut + MatrixEntity, + M::R: RowMut + RowEntity + Clone, + M::MatElement: PrimInt + + FromPrimitive + + One + + Copy + + Zero + + Display + + WrappingSub + + NumInfo + + From + + WrappingAdd + + Debug, + RlweModOp: VectorOps + + ArithmeticOps + + ShoupMatrixFMA, + LweModOp: VectorOps + ArithmeticOps, + NttOp: Ntt, + Skey: PbsKey::RgswCt, LweKskKey = M>, + ::RgswCt: WithShoupRepr, +{ + type Ciphertext = M::R; + type Key = Skey; + + fn nand_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.nand_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + } + + fn and_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.and_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + } + + fn or_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.or_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + } + + fn nor_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.nor_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ) + } + + fn xor_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._subtract_double_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.xor_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + } + + fn xnor_inplace(&mut self, c0: &mut M::R, c1: &M::R, server_key: &Self::Key) { + self._subtract_double_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.xnor_test_vec, + c0, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + } + + fn not_inplace(&self, c0: &mut M::R) { + let modop = &self.pbs_info.rlwe_modop; + c0.as_mut().iter_mut().for_each(|v| *v = modop.neg(v)); + } + + fn and( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.and_inplace(&mut out, c1, key); + out + } + + fn nand( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.nand_inplace(&mut out, c1, key); + out + } + + fn or( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.or_inplace(&mut out, c1, key); + out + } + + fn nor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.nor_inplace(&mut out, c1, key); + out + } + + fn xnor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.xnor_inplace(&mut out, c1, key); + out + } + + fn xor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.xor_inplace(&mut out, c1, key); + out + } + + fn not(&self, c: &Self::Ciphertext) -> Self::Ciphertext { + let mut out = c.clone(); + self.not_inplace(&mut out); + out + } +} diff --git a/src/bool/keys.rs b/src/bool/keys.rs new file mode 100644 index 0000000..16dfdf1 --- /dev/null +++ b/src/bool/keys.rs @@ -0,0 +1,1559 @@ +use std::{collections::HashMap, marker::PhantomData}; + +use crate::{ + backend::{ModInit, VectorOps}, + pbs::WithShoupRepr, + random::{NewWithSeed, RandomFillUniformInModulus}, + utils::ToShoup, + Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, +}; + +use super::parameters::{BoolParameters, CiphertextModulus}; + +pub(crate) trait SinglePartyClientKey { + type Element; + fn sk_rlwe(&self) -> Vec; + fn sk_lwe(&self) -> Vec; +} + +pub(crate) trait InteractiveMultiPartyClientKey { + type Element; + fn sk_rlwe(&self) -> Vec; + fn sk_lwe(&self) -> Vec; +} + +pub(crate) trait NonInteractiveMultiPartyClientKey { + type Element; + fn sk_rlwe(&self) -> Vec; + fn sk_u_rlwe(&self) -> Vec; + fn sk_lwe(&self) -> Vec; +} + +/// Client key +/// +/// Key is used for all parameter varians - Single party, interactive +/// multi-party, and non-interactive multi-party. The only stored the main seed +/// and seeds of the Rlwe/Lwe secrets are derived at puncturing the seed desired +/// number of times. +/// +/// ### Punctures required: +/// +/// Puncture 1 -> Seed of RLWE secret used as main RLWE secret for +/// single-party, interactive/non-interactive multi-party +/// +/// Puncture 2 -> Seed of LWE secret used main LWE secret for single-party, +/// interactive/non-interactive multi-party +/// +/// Puncture 3 -> Seed of RLWE secret used as `u` in +/// non-interactive multi-party. +#[derive(Clone)] +pub struct ClientKey { + seed: S, + parameters: BoolParameters, +} + +mod impl_ck { + use crate::{ + parameters::SecretKeyDistribution, + random::{DefaultSecureRng, RandomFillGaussian}, + utils::{fill_random_ternary_secret_with_hamming_weight, puncture_p_rng}, + }; + + use super::*; + + impl ClientKey<[u8; 32], E> { + pub(in super::super) fn new(parameters: BoolParameters) -> ClientKey<[u8; 32], E> { + let mut rng = DefaultSecureRng::new(); + let mut seed = [0u8; 32]; + rng.fill_bytes(&mut seed); + Self { seed, parameters } + } + } + + impl SinglePartyClientKey for ClientKey<[u8; 32], E> { + type Element = i32; + fn sk_lwe(&self) -> Vec { + let mut p_rng = DefaultSecureRng::new_seeded(self.seed); + let lwe_seed = puncture_p_rng::<[u8; 32], DefaultSecureRng>(&mut p_rng, 2); + + let mut lwe_prng = DefaultSecureRng::new_seeded(lwe_seed); + + let mut out = vec![0i32; self.parameters.lwe_n().0]; + + match self.parameters.lwe_secret_key_dist() { + &SecretKeyDistribution::ErrorDistribution => { + RandomFillGaussian::random_fill(&mut lwe_prng, &mut out); + } + &SecretKeyDistribution::TernaryDistribution => { + fill_random_ternary_secret_with_hamming_weight( + &mut out, + self.parameters.lwe_n().0 >> 1, + &mut lwe_prng, + ); + } + } + out + } + fn sk_rlwe(&self) -> Vec { + assert!( + self.parameters.rlwe_secret_key_dist() + == &SecretKeyDistribution::TernaryDistribution + ); + + let mut p_rng = DefaultSecureRng::new_seeded(self.seed); + let rlwe_seed = puncture_p_rng::<[u8; 32], DefaultSecureRng>(&mut p_rng, 1); + + let mut rlwe_prng = DefaultSecureRng::new_seeded(rlwe_seed); + let mut out = vec![0i32; self.parameters.rlwe_n().0]; + fill_random_ternary_secret_with_hamming_weight( + &mut out, + self.parameters.rlwe_n().0 >> 1, + &mut rlwe_prng, + ); + out + } + } + + #[cfg(feature = "interactive_mp")] + impl InteractiveMultiPartyClientKey for ClientKey<[u8; 32], E> { + type Element = i32; + fn sk_lwe(&self) -> Vec { + ::sk_lwe(&self) + } + fn sk_rlwe(&self) -> Vec { + ::sk_rlwe(&self) + } + } + + #[cfg(feature = "non_interactive_mp")] + impl NonInteractiveMultiPartyClientKey for ClientKey<[u8; 32], E> { + type Element = i32; + fn sk_lwe(&self) -> Vec { + ::sk_lwe(&self) + } + fn sk_rlwe(&self) -> Vec { + ::sk_rlwe(&self) + } + fn sk_u_rlwe(&self) -> Vec { + assert!( + self.parameters.rlwe_secret_key_dist() + == &SecretKeyDistribution::TernaryDistribution + ); + + let mut p_rng = DefaultSecureRng::new_seeded(self.seed); + let rlwe_seed = puncture_p_rng::<[u8; 32], DefaultSecureRng>(&mut p_rng, 3); + + let mut rlwe_prng = DefaultSecureRng::new_seeded(rlwe_seed); + let mut out = vec![0i32; self.parameters.rlwe_n().0]; + fill_random_ternary_secret_with_hamming_weight( + &mut out, + self.parameters.rlwe_n().0 >> 1, + &mut rlwe_prng, + ); + out + } + } +} + +/// Public key +pub struct PublicKey { + key: M, + _phantom: PhantomData<(Rng, ModOp)>, +} + +pub(super) mod impl_pk { + use super::*; + + impl PublicKey { + pub(in super::super) fn key(&self) -> &M { + &self.key + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + ModOp, + > From, ModOp>> + for PublicKey + where + ::R: RowMut, + M::MatElement: Copy, + { + fn from( + value: SeededPublicKey, ModOp>, + ) -> Self { + let mut prng = Rng::new_with_seed(value.seed); + + let mut key = M::zeros(2, value.parameters.rlwe_n().0); + // sample A + RandomFillUniformInModulus::random_fill( + &mut prng, + value.parameters.rlwe_q(), + key.get_row_mut(0), + ); + // Copy over B + key.get_row_mut(1).copy_from_slice(value.part_b.as_ref()); + + PublicKey { + key, + _phantom: PhantomData, + } + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + ModOp: VectorOps + ModInit>, + > + From< + &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + > for PublicKey + where + ::R: RowMut, + Rng::Seed: Copy + PartialEq, + M::MatElement: PartialEq + Copy, + { + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let mut key = M::zeros(2, parameters.rlwe_n().0); + + // sample A + let seed = value[0].cr_seed; + let mut main_rng = Rng::new_with_seed(seed); + RandomFillUniformInModulus::random_fill( + &mut main_rng, + parameters.rlwe_q(), + key.get_row_mut(0), + ); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); + value.iter().for_each(|share_i| { + assert!(share_i.cr_seed == seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(key.get_row_mut(1), share_i.share.as_ref()); + }); + + PublicKey { + key, + _phantom: PhantomData, + } + } + } +} + +/// Seeded public key +struct SeededPublicKey { + part_b: Ro, + seed: S, + parameters: P, + _phantom: PhantomData, +} + +mod impl_seeded_pk { + use super::*; + + impl + From<&[CommonReferenceSeededCollectivePublicKeyShare>]> + for SeededPublicKey, ModOp> + where + ModOp: VectorOps + ModInit>, + S: PartialEq + Clone, + R: RowMut + RowEntity + Clone, + R::Element: Clone + PartialEq, + { + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare< + R, + S, + BoolParameters, + >], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let cr_seed = value[0].cr_seed.clone(); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); + let mut part_b = value[0].share.clone(); + value.iter().skip(1).for_each(|share_i| { + assert!(&share_i.cr_seed == &cr_seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(part_b.as_mut(), share_i.share.as_ref()); + }); + + Self { + part_b, + seed: cr_seed, + parameters: parameters.clone(), + _phantom: PhantomData, + } + } + } +} + +/// CRS seeded collective public key share +pub struct CommonReferenceSeededCollectivePublicKeyShare { + /// Public key share polynomial + share: Ro, + /// Common reference seed + cr_seed: S, + /// Parameters + parameters: P, +} +impl CommonReferenceSeededCollectivePublicKeyShare { + pub(super) fn new(share: Ro, cr_seed: S, parameters: P) -> Self { + CommonReferenceSeededCollectivePublicKeyShare { + share, + cr_seed, + parameters, + } + } +} + +/// Common reference seed seeded interactive multi-party server key share +pub struct CommonReferenceSeededInteractiveMultiPartyServerKeyShare { + /// Public key encrypted RGSW(m = X^{s[i]}) ciphertexts for LWE secret + /// indices for which `Self` is the leader. Note that when `Self` is + /// leader RGSW ciphertext is encrypted using RLWE x RGSW decomposer + self_leader_rgsws: Vec, + /// Public key encrypted RGSW(m = X^{s[i]}) ciphertext for LWE secret + /// indices for which `Self` is `not` the leader. Note that when `Self` + /// is not the leader RGSW ciphertext is encrypted using RGSW1 + /// decomposer for RGSW0 x RGSW1 + not_self_leader_rgsws: Vec, + /// Auto key shares for auto elements [-g, g, g^2, .., g^{w}] where `w` + /// is the window size parameter. Share corresponding to auto element -g + /// is stored at key `0` and share corresponding to auto element g^{k} is + /// stored at key `k`. + auto_keys: HashMap, + /// LWE key switching key share to key switching ciphertext LWE_{q, s}(m) to + /// LWE_{q, z}(m) where q is LWE ciphertext modulus, `s` is the ideal RLWE + /// secret with dimension N, and `z` is the ideal LWE secret of dimension n. + lwe_ksk: M::R, + /// Common reference seed + cr_seed: S, + parameters: P, + /// User id assigned by the server. + /// + /// User id must be unique and a number in range [0, total_users) + user_id: usize, +} + +impl CommonReferenceSeededInteractiveMultiPartyServerKeyShare { + pub(super) fn new( + self_leader_rgsws: Vec, + not_self_leader_rgsws: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + user_id: usize, + ) -> Self { + CommonReferenceSeededInteractiveMultiPartyServerKeyShare { + self_leader_rgsws, + not_self_leader_rgsws, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + user_id, + } + } + + pub(super) fn cr_seed(&self) -> &S { + &self.cr_seed + } + + pub(super) fn parameters(&self) -> &P { + &self.parameters + } + + pub(super) fn auto_keys(&self) -> &HashMap { + &self.auto_keys + } + + pub(crate) fn self_leader_rgsws(&self) -> &[M] { + &self.self_leader_rgsws + } + + pub(super) fn not_self_leader_rgsws(&self) -> &[M] { + &self.not_self_leader_rgsws + } + + pub(super) fn lwe_ksk(&self) -> &M::R { + &self.lwe_ksk + } + + pub(super) fn user_id(&self) -> usize { + self.user_id + } +} + +/// Common reference seeded interactive multi-party server key +pub struct SeededInteractiveMultiPartyServerKey { + /// RGSW ciphertexts RGSW(X^{s[i]}) encrypted under ideal RLWE secret key + /// where `s` is ideal LWE secret key for each LWE secret dimension. + rgsw_cts: Vec, + /// Seeded auto keys under ideal RLWE secret for RLWE automorphisms with + /// auto elements [-g, g, g^2,..., g^{w}]. Auto key corresponidng to + /// auto element -g is stored at key `0` and key corresponding to auto + /// element g^{k} is stored at key `k` + auto_keys: HashMap, + /// Seeded LWE key switching key under ideal LWE secret to switch LWE_{q, + /// s}(m) to LWE_{q, z}(m) where s is ideal RLWE secret and z is ideal LWE + /// secret. + lwe_ksk: M::R, + /// Common reference seed + cr_seed: S, + parameters: P, +} + +impl SeededInteractiveMultiPartyServerKey { + pub(super) fn new( + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + ) -> Self { + SeededInteractiveMultiPartyServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + } + } + + pub(super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } +} + +/// Seeded single party server key +pub struct SeededSinglePartyServerKey { + /// Rgsw cts of LWE secret elements + pub(crate) rgsw_cts: Vec, + /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding + /// to -g is at 0 + pub(crate) auto_keys: HashMap, + /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret + pub(crate) lwe_ksk: M::R, + /// Parameters + pub(crate) parameters: P, + /// Main seed + pub(crate) seed: S, +} +impl SeededSinglePartyServerKey, S> { + pub(super) fn from_raw( + auto_keys: HashMap, + rgsw_cts: Vec, + lwe_ksk: M::R, + parameters: BoolParameters, + seed: S, + ) -> Self { + // sanity checks + auto_keys.iter().for_each(|v| { + assert!( + v.1.dimension() + == ( + parameters.auto_decomposition_count().0, + parameters.rlwe_n().0 + ) + ) + }); + + let (part_a_d, part_b_d) = parameters.rlwe_rgsw_decomposition_count(); + rgsw_cts.iter().for_each(|v| { + assert!(v.dimension() == (part_a_d.0 * 2 + part_b_d.0, parameters.rlwe_n().0)) + }); + assert!( + lwe_ksk.as_ref().len() + == (parameters.lwe_decomposition_count().0 * parameters.rlwe_n().0) + ); + + SeededSinglePartyServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + parameters, + seed, + } + } +} + +/// Server key in evaluation domain +pub(crate) struct ServerKeyEvaluationDomain { + /// RGSW ciphertext RGSW(X^{s[i]}) for each LWE index in evaluation domain + rgsw_cts: Vec, + /// Auto keys for all auto elements [-g, g, g^2,..., g^w] in evaluation + /// domain + galois_keys: HashMap, + /// LWE key switching key to key switch LWE_{q, s}(m) to LWE_{q, z}(m) + lwe_ksk: M, + parameters: P, + _phanton: PhantomData<(R, N)>, +} + +pub(super) mod impl_server_key_eval_domain { + use itertools::{izip, Itertools}; + + use crate::{ + bool::evaluator::InteractiveMultiPartyCrs, + ntt::{Ntt, NttInit}, + pbs::PbsKey, + random::RandomFill, + }; + + use super::*; + + impl ServerKeyEvaluationDomain { + pub(in super::super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } + } + + impl< + M: MatrixMut + MatrixEntity, + R: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + NewWithSeed, + N: NttInit> + Ntt, + > From<&SeededSinglePartyServerKey, R::Seed>> + for ServerKeyEvaluationDomain, R, N> + where + ::R: RowMut, + M::MatElement: Copy, + R::Seed: Clone, + { + fn from( + value: &SeededSinglePartyServerKey, R::Seed>, + ) -> Self { + let mut main_prng = R::new_with_seed(value.seed.clone()); + let parameters = &value.parameters; + let g = parameters.g() as isize; + let ring_size = value.parameters.rlwe_n().0; + let lwe_n = value.parameters.lwe_n().0; + let rlwe_q = value.parameters.rlwe_q(); + let lwq_q = value.parameters.lwe_q(); + + let nttop = N::new(rlwe_q, ring_size); + + // galois keys + let mut auto_keys = HashMap::new(); + let auto_decomp_count = parameters.auto_decomposition_count().0; + let auto_element_dlogs = parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let seeded_auto_key = value.auto_keys.get(&i).unwrap(); + assert!(seeded_auto_key.dimension() == (auto_decomp_count, ring_size)); + + let mut data = M::zeros(auto_decomp_count * 2, ring_size); + + // sample RLWE'_A(-s(X^k)) + data.iter_rows_mut().take(auto_decomp_count).for_each(|ri| { + RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) + }); + + // copy over RLWE'B_(-s(X^k)) + izip!( + data.iter_rows_mut().skip(auto_decomp_count), + seeded_auto_key.iter_rows() + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // Send to Evaluation domain + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + auto_keys.insert(i, data); + } + + // RGSW ciphertexts + let (rlrg_a_decomp, rlrg_b_decomp) = parameters.rlwe_rgsw_decomposition_count(); + let rgsw_cts = value + .rgsw_cts + .iter() + .map(|seeded_rgsw_si| { + assert!( + seeded_rgsw_si.dimension() + == (rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0, ring_size) + ); + + let mut data = M::zeros(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0 * 2, ring_size); + + // copy over RLWE'(-sm) + izip!( + data.iter_rows_mut().take(rlrg_a_decomp.0 * 2), + seeded_rgsw_si.iter_rows().take(rlrg_a_decomp.0 * 2) + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // sample RLWE'_A(m) + data.iter_rows_mut() + .skip(rlrg_a_decomp.0 * 2) + .take(rlrg_b_decomp.0) + .for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut main_prng, + &rlwe_q, + ri.as_mut(), + ) + }); + + // copy over RLWE'_B(m) + izip!( + data.iter_rows_mut() + .skip(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0), + seeded_rgsw_si.iter_rows().skip(rlrg_a_decomp.0 * 2) + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send polynomials to evaluation domain + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + data + }) + .collect_vec(); + + // LWE ksk + let lwe_ksk = { + let d = parameters.lwe_decomposition_count().0; + assert!(value.lwe_ksk.as_ref().len() == d * ring_size); + + let mut data = M::zeros(d * ring_size, lwe_n + 1); + izip!(data.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each( + |(lwe_i, bi)| { + RandomFillUniformInModulus::random_fill( + &mut main_prng, + &lwq_q, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }, + ); + + data + }; + + ServerKeyEvaluationDomain { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk, + parameters: parameters.clone(), + _phanton: PhantomData, + } + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed, + N: NttInit> + Ntt, + > + From< + &SeededInteractiveMultiPartyServerKey< + M, + InteractiveMultiPartyCrs, + BoolParameters, + >, + > for ServerKeyEvaluationDomain, Rng, N> + where + ::R: RowMut, + Rng::Seed: Copy + Default, + Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + RandomFill, + M::MatElement: Copy, + { + fn from( + value: &SeededInteractiveMultiPartyServerKey< + M, + InteractiveMultiPartyCrs, + BoolParameters, + >, + ) -> Self { + let g = value.parameters.g() as isize; + let rlwe_n = value.parameters.rlwe_n().0; + let lwe_n = value.parameters.lwe_n().0; + let rlwe_q = value.parameters.rlwe_q(); + let lwe_q = value.parameters.lwe_q(); + + let rlwe_nttop = N::new(rlwe_q, rlwe_n); + + // auto keys + let mut auto_keys = HashMap::new(); + { + let mut auto_prng = Rng::new_with_seed(value.cr_seed.auto_keys_cts_seed::()); + let auto_d_count = value.parameters.auto_decomposition_count().0; + let auto_element_dlogs = value.parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let mut key = M::zeros(auto_d_count * 2, rlwe_n); + + // sample a + key.iter_rows_mut().take(auto_d_count).for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut auto_prng, + &rlwe_q, + ri.as_mut(), + ) + }); + + let key_part_b = value.auto_keys.get(&i).unwrap(); + assert!(key_part_b.dimension() == (auto_d_count, rlwe_n)); + izip!( + key.iter_rows_mut().skip(auto_d_count), + key_part_b.iter_rows() + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // send to evaluation domain + key.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); + + auto_keys.insert(i, key); + } + } + + // rgsw cts + let (rlrg_d_a, rlrg_d_b) = value.parameters.rlwe_rgsw_decomposition_count(); + let rgsw_ct_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; + let rgsw_cts = value + .rgsw_cts + .iter() + .map(|ct_i_in| { + assert!(ct_i_in.dimension() == (rgsw_ct_out, rlwe_n)); + let mut eval_ct_i_out = M::zeros(rgsw_ct_out, rlwe_n); + + izip!(eval_ct_i_out.iter_rows_mut(), ct_i_in.iter_rows()).for_each( + |(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }, + ); + + eval_ct_i_out + }) + .collect_vec(); + + // lwe ksk + let mut lwe_ksk_prng = Rng::new_with_seed(value.cr_seed.lwe_ksk_cts_seed_seed::()); + let d_lwe = value.parameters.lwe_decomposition_count().0; + let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); + izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each( + |(lwe_i, bi)| { + RandomFillUniformInModulus::random_fill( + &mut lwe_ksk_prng, + &lwe_q, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }, + ); + + ServerKeyEvaluationDomain { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk, + parameters: value.parameters.clone(), + _phanton: PhantomData, + } + } + } + + impl PbsKey for ServerKeyEvaluationDomain { + type AutoKey = M; + type LweKskKey = M; + type RgswCt = M; + + fn galois_key_for_auto(&self, k: usize) -> &Self::AutoKey { + self.galois_keys.get(&k).unwrap() + } + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::RgswCt { + &self.rgsw_cts[si] + } + + fn lwe_ksk(&self) -> &Self::LweKskKey { + &self.lwe_ksk + } + } + + #[cfg(test)] + impl super::super::print_noise::CollectRuntimeServerKeyStats + for ServerKeyEvaluationDomain + { + type M = M; + fn galois_key_for_auto(&self, k: usize) -> &Self::M { + self.galois_keys.get(&k).unwrap() + } + fn lwe_ksk(&self) -> &Self::M { + &self.lwe_ksk + } + fn rgsw_cts_lwe_si(&self, s_index: usize) -> &Self::M { + &self.rgsw_cts[s_index] + } + } +} + +/// Non-interactive multi-party server key in evaluation domain. +/// +/// The key is derived from Seeded non-interactive mmulti-party server key +/// `SeededNonInteractiveMultiPartyServerKey`. +pub(crate) struct NonInteractiveServerKeyEvaluationDomain { + /// RGSW ciphertexts RGSW(X^{s[i]}) under ideal RLWE secret key in + /// evaluation domain + rgsw_cts: Vec, + /// Auto keys for all auto elements [-g, g, g^2, g^w] in evaluation + /// domain + auto_keys: HashMap, + /// LWE key switching key to key switch LWE_{q, s}(m) to LWE_{q, z}(m) + lwe_ksk: M, + /// Key switching key from user j's secret u_j to ideal RLWE secret key `s` + /// in evaluation domain. User j's key switching key is at j'th index. + ui_to_s_ksks: Vec, + parameters: P, + _phanton: PhantomData<(R, N)>, +} + +pub(super) mod impl_non_interactive_server_key_eval_domain { + use itertools::{izip, Itertools}; + + use crate::{bool::evaluator::NonInteractiveMultiPartyCrs, random::RandomFill, Ntt, NttInit}; + + use super::*; + + impl NonInteractiveServerKeyEvaluationDomain { + pub(in super::super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } + } + + impl + From< + &SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs, + BoolParameters, + >, + > for NonInteractiveServerKeyEvaluationDomain, Rng, N> + where + M: MatrixMut + MatrixEntity + Clone, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + RandomFill<::Seed>, + N: Ntt + NttInit>, + M::R: RowMut, + M::MatElement: Copy, + Rng::Seed: Clone + Copy + Default, + { + fn from( + value: &SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs, + BoolParameters, + >, + ) -> Self { + let rlwe_nttop = N::new(value.parameters.rlwe_q(), value.parameters.rlwe_n().0); + let ring_size = value.parameters.rlwe_n().0; + + // RGSW cts + // copy over rgsw cts and send to evaluation domain + let mut rgsw_cts = value.rgsw_cts.clone(); + rgsw_cts.iter_mut().for_each(|c| { + c.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())) + }); + + // Auto keys + // populate pseudo random part of auto keys. Then send auto keys to + // evaluation domain + let mut auto_keys = HashMap::new(); + let auto_seed = value.cr_seed.auto_keys_cts_seed::(); + let mut auto_prng = Rng::new_with_seed(auto_seed); + let auto_element_dlogs = value.parameters.auto_element_dlogs(); + let d_auto = value.parameters.auto_decomposition_count().0; + auto_element_dlogs.iter().for_each(|el| { + let auto_part_b = value + .auto_keys + .get(el) + .expect(&format!("Auto key for element g^{el} not found")); + + assert!(auto_part_b.dimension() == (d_auto, ring_size)); + + let mut auto_ct = M::zeros(d_auto * 2, ring_size); + + // sample part A + auto_ct.iter_rows_mut().take(d_auto).for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut auto_prng, + value.parameters.rlwe_q(), + ri.as_mut(), + ) + }); + + // Copy over part B + izip!( + auto_ct.iter_rows_mut().skip(d_auto), + auto_part_b.iter_rows() + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send to evaluation domain + auto_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.forward(r.as_mut())); + + auto_keys.insert(*el, auto_ct); + }); + + // LWE ksk + // populate pseudo random part of lwe ciphertexts in ksk and copy over part b + // elements + let lwe_ksk_seed = value.cr_seed.lwe_ksk_cts_seed::(); + let mut lwe_ksk_prng = Rng::new_with_seed(lwe_ksk_seed); + let mut lwe_ksk = M::zeros( + value.parameters.lwe_decomposition_count().0 * ring_size, + value.parameters.lwe_n().0 + 1, + ); + lwe_ksk.iter_rows_mut().for_each(|ri| { + // first element is resereved for part b. Only sample a_is in the rest + RandomFillUniformInModulus::random_fill( + &mut lwe_ksk_prng, + value.parameters.lwe_q(), + &mut ri.as_mut()[1..], + ) + }); + // copy over part bs + assert!( + value.lwe_ksk.as_ref().len() + == value.parameters.lwe_decomposition_count().0 * ring_size + ); + izip!(value.lwe_ksk.as_ref().iter(), lwe_ksk.iter_rows_mut()).for_each( + |(b_el, lwe_ct)| { + lwe_ct.as_mut()[0] = *b_el; + }, + ); + + // u_i to s ksk + let d_uitos = value + .parameters + .non_interactive_ui_to_s_key_switch_decomposition_count() + .0; + let ui_to_s_ksks = value + .ui_to_s_ksks + .iter() + .enumerate() + .map(|(user_id, incoming_ksk_partb)| { + let user_i_seed = value.cr_seed.ui_to_s_ks_seed_for_user_i::(user_id); + let mut prng = Rng::new_with_seed(user_i_seed); + + let mut ksk_ct = M::zeros(d_uitos * 2, ring_size); + + ksk_ct.iter_rows_mut().take(d_uitos).for_each(|r| { + RandomFillUniformInModulus::random_fill( + &mut prng, + value.parameters.rlwe_q(), + r.as_mut(), + ); + }); + + assert!(incoming_ksk_partb.dimension() == (d_uitos, ring_size)); + izip!( + ksk_ct.iter_rows_mut().skip(d_uitos), + incoming_ksk_partb.iter_rows() + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + ksk_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.forward(r.as_mut())); + ksk_ct + }) + .collect_vec(); + + NonInteractiveServerKeyEvaluationDomain { + rgsw_cts, + auto_keys, + lwe_ksk, + ui_to_s_ksks, + parameters: value.parameters.clone(), + _phanton: PhantomData, + } + } + } + + #[cfg(test)] + impl super::super::print_noise::CollectRuntimeServerKeyStats + for NonInteractiveServerKeyEvaluationDomain + { + type M = M; + fn galois_key_for_auto(&self, k: usize) -> &Self::M { + self.auto_keys.get(&k).unwrap() + } + fn lwe_ksk(&self) -> &Self::M { + &self.lwe_ksk + } + fn rgsw_cts_lwe_si(&self, s_index: usize) -> &Self::M { + &self.rgsw_cts[s_index] + } + } +} + +/// Seeded non-interactive multi-party server key. +/// +/// Given common reference seeded non-interactive multi-party key shares of each +/// users with unique user-ids, seeded non-interactive can be generated using +/// `BoolEvaluator::aggregate_non_interactive_multi_party_key_share` +pub struct SeededNonInteractiveMultiPartyServerKey { + /// Key switching key from user j's secret u_j to ideal RLWE secret key `s`. + /// User j's key switching key is at j'th index. + ui_to_s_ksks: Vec, + /// RGSW ciphertexts RGSW(X^{s[i]}) under ideal RLWE secret key + rgsw_cts: Vec, + /// Auto keys for all auto elements [-g, g, g^2, g^w] + auto_keys: HashMap, + /// LWE key switching key to key switch LWE_{q, s}(m) to LWE_{q, z}(m) + lwe_ksk: M::R, + /// Common reference seed + cr_seed: S, + parameters: P, +} + +impl SeededNonInteractiveMultiPartyServerKey { + pub(super) fn new( + ui_to_s_ksks: Vec, + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + ) -> Self { + Self { + ui_to_s_ksks, + + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + } + } +} + +/// This key is equivalent to NonInteractiveServerKeyEvaluationDomain with the +/// addition that each polynomial in evaluation domain has a corresponding shoup +/// representation suitable for shoup multiplication. +pub(crate) struct ShoupNonInteractiveServerKeyEvaluationDomain { + rgsw_cts: Vec>, + auto_keys: HashMap>, + lwe_ksk: M, + ui_to_s_ksks: Vec>, +} + +mod impl_shoup_non_interactive_server_key_eval_domain { + use itertools::Itertools; + use num_traits::{FromPrimitive, PrimInt, ToPrimitive}; + + use super::*; + use crate::{backend::Modulus, decomposer::NumInfo, pbs::PbsKey}; + + impl ShoupNonInteractiveServerKeyEvaluationDomain { + pub(in super::super) fn ui_to_s_ksk(&self, user_id: usize) -> &NormalAndShoup { + &self.ui_to_s_ksks[user_id] + } + } + + impl, R, N> + From, R, N>> + for ShoupNonInteractiveServerKeyEvaluationDomain + where + M::MatElement: FromPrimitive + ToPrimitive + PrimInt + NumInfo, + { + fn from( + value: NonInteractiveServerKeyEvaluationDomain, R, N>, + ) -> Self { + let rlwe_q = value.parameters.rlwe_q().q().unwrap(); + + let rgsw_dim = ( + value.parameters.rlwe_rgsw_decomposition_count().0 .0 * 2 + + value.parameters.rlwe_rgsw_decomposition_count().1 .0 * 2, + value.parameters.rlwe_n().0, + ); + let rgsw_cts = value + .rgsw_cts + .into_iter() + .map(|m| { + assert!(m.dimension() == rgsw_dim); + NormalAndShoup::new_with_modulus(m, rlwe_q) + }) + .collect_vec(); + + let auto_dim = ( + value.parameters.auto_decomposition_count().0 * 2, + value.parameters.rlwe_n().0, + ); + let mut auto_keys = HashMap::new(); + value.auto_keys.into_iter().for_each(|(k, v)| { + assert!(v.dimension() == auto_dim); + auto_keys.insert(k, NormalAndShoup::new_with_modulus(v, rlwe_q)); + }); + + let ui_ks_dim = ( + value + .parameters + .non_interactive_ui_to_s_key_switch_decomposition_count() + .0 + * 2, + value.parameters.rlwe_n().0, + ); + let ui_to_s_ksks = value + .ui_to_s_ksks + .into_iter() + .map(|m| { + assert!(m.dimension() == ui_ks_dim); + NormalAndShoup::new_with_modulus(m, rlwe_q) + }) + .collect_vec(); + + assert!( + value.lwe_ksk.dimension() + == ( + value.parameters.rlwe_n().0 * value.parameters.lwe_decomposition_count().0, + value.parameters.lwe_n().0 + 1 + ) + ); + + Self { + rgsw_cts, + auto_keys, + lwe_ksk: value.lwe_ksk, + ui_to_s_ksks, + } + } + } + + impl PbsKey for ShoupNonInteractiveServerKeyEvaluationDomain { + type AutoKey = NormalAndShoup; + type LweKskKey = M; + type RgswCt = NormalAndShoup; + + fn galois_key_for_auto(&self, k: usize) -> &Self::AutoKey { + let d = self.auto_keys.get(&k).unwrap(); + d + } + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::RgswCt { + &self.rgsw_cts[si] + } + + fn lwe_ksk(&self) -> &Self::LweKskKey { + &self.lwe_ksk + } + } +} + +/// This is equivalent to ServerKeyEvaluationDomain with the addition that each +/// polynomial in evaluation domain has corresponding shoup representation +/// suitable for shoup multiplication. +pub(crate) struct ShoupServerKeyEvaluationDomain { + rgsw_cts: Vec>, + galois_keys: HashMap>, + lwe_ksk: M, +} + +mod shoup_server_key_eval_domain { + use itertools::{izip, Itertools}; + use num_traits::{FromPrimitive, PrimInt}; + + use crate::{backend::Modulus, decomposer::NumInfo, pbs::PbsKey}; + + use super::*; + + impl, R, N> + From, R, N>> + for ShoupServerKeyEvaluationDomain + where + ::R: RowMut, + M::MatElement: PrimInt + FromPrimitive + NumInfo, + { + fn from(value: ServerKeyEvaluationDomain, R, N>) -> Self { + let q = value.parameters.rlwe_q().q().unwrap(); + // Rgsw ciphertexts + let rgsw_cts = value + .rgsw_cts + .into_iter() + .map(|ct| NormalAndShoup::new_with_modulus(ct, q)) + .collect_vec(); + + let mut auto_keys = HashMap::new(); + value.galois_keys.into_iter().for_each(|(index, key)| { + auto_keys.insert(index, NormalAndShoup::new_with_modulus(key, q)); + }); + + Self { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk: value.lwe_ksk, + } + } + } + + impl PbsKey for ShoupServerKeyEvaluationDomain { + type AutoKey = NormalAndShoup; + type LweKskKey = M; + type RgswCt = NormalAndShoup; + + fn galois_key_for_auto(&self, k: usize) -> &Self::AutoKey { + self.galois_keys.get(&k).unwrap() + } + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::RgswCt { + &self.rgsw_cts[si] + } + + fn lwe_ksk(&self) -> &Self::LweKskKey { + &self.lwe_ksk + } + } +} + +pub struct CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare { + /// Non-interactive RGSW ciphertexts for LWE secret indices for which user + /// is the leader + self_leader_ni_rgsw_cts: Vec, + /// Non-interactive RGSW ciphertexts for LWE secret indices for which user + /// is not the leader + not_self_leader_ni_rgsw_cts: Vec, + /// Zero encryptions for RGSW ciphertexts for all indices + ni_rgsw_zero_encs: Vec, + + /// Key switching key from u_j to s where u_j is user j's RLWE secret `u` + /// and `s` is ideal RLWE secret. Note that in server key share the key + /// switching key is encrypted under user j's RLWE secret `s_j`. It is + /// then switched to ideal RLWE secret after adding zero encryptions + /// generated using same `a_k`s from other users. + /// + /// That is the key share has the following key switching key: + /// (a_k*s_j + e + \beta u_j, a_k*s_j + e) + ui_to_s_ksk: M, + /// Zero encryptions to switch user l's key switching key u_l to s from + /// user l's RLWE secret s_l to ideal RLWE secret `s`. + /// + /// If there are P total parties then zero encryption sets are generated for + /// each party l \in [0, P) and l != j where j self's user_id. + /// + /// Zero encryption set for user `l` is stored at index l is l < j otherwise + /// it is stored at index l - 1, where j is self's user_id + ksk_zero_encs_for_others: Vec, + + /// RLWE auto key shares for auto elements [-g, g, g^2, g^{w}] where `w` + /// is the window size. Auto key share corresponding to auto element -g + /// is stored at key 0 and key share corresponding to auto element g^{k} is + /// stored at key `k` + auto_keys_share: HashMap, + /// LWE key switching key share to key switching LWE_{q, s}(m) to LWE_{q, + /// z}(m) + lwe_ksk_share: M::R, + + /// User's id. + /// + /// If there are P total parties, then user id must be inque and in range + /// [0, P) + user_id: usize, + /// Total users participating in multi-party compute + total_users: usize, + /// LWE dimension + lwe_n: usize, + /// Common reference seed + cr_seed: S, + parameters: P, +} + +mod impl_common_ref_non_interactive_multi_party_server_share { + use crate::bool::evaluator::multi_party_user_id_lwe_segment; + + use super::*; + + impl CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare { + pub(in super::super) fn new( + self_leader_ni_rgsw_cts: Vec, + not_self_leader_ni_rgsw_cts: Vec, + ni_rgsw_zero_encs: Vec, + ui_to_s_ksk: M, + ksk_zero_encs_for_others: Vec, + auto_keys_share: HashMap, + lwe_ksk_share: M::R, + user_index: usize, + total_users: usize, + lwe_n: usize, + cr_seed: S, + parameters: P, + ) -> Self { + Self { + self_leader_ni_rgsw_cts, + not_self_leader_ni_rgsw_cts, + ni_rgsw_zero_encs, + ui_to_s_ksk, + ksk_zero_encs_for_others, + auto_keys_share, + lwe_ksk_share, + user_id: user_index, + total_users, + lwe_n, + cr_seed, + parameters, + } + } + + pub(in super::super) fn ni_rgsw_cts_for_self_leader_lwe_index( + &self, + lwe_index: usize, + ) -> &M { + let self_segment = + multi_party_user_id_lwe_segment(self.user_id, self.total_users, self.lwe_n); + assert!(lwe_index >= self_segment.0 && lwe_index < self_segment.1); + &self.self_leader_ni_rgsw_cts[lwe_index - self_segment.0] + } + + pub(in super::super) fn ni_rgsw_cts_for_self_not_leader_lwe_index( + &self, + lwe_index: usize, + ) -> &M { + let self_segment = + multi_party_user_id_lwe_segment(self.user_id, self.total_users, self.lwe_n); + // Non-interactive RGSW cts when self is not leader are stored in + // sorted-order. For ex, if self is the leader for indices (5, 6] + // then self stores NI-RGSW cts for rest of indices like [0, 1, 2, + // 3, 4, 6, 7, 8, 9] + assert!(lwe_index < self.lwe_n); + assert!(lwe_index < self_segment.0 || lwe_index >= self_segment.1); + if lwe_index < self_segment.0 { + &self.not_self_leader_ni_rgsw_cts[lwe_index] + } else { + &self.not_self_leader_ni_rgsw_cts[lwe_index - (self_segment.1 - self_segment.0)] + } + } + + pub(in super::super) fn ni_rgsw_zero_enc_for_lwe_index(&self, lwe_index: usize) -> &M { + &self.ni_rgsw_zero_encs[lwe_index] + } + + pub(in super::super) fn ui_to_s_ksk(&self) -> &M { + &self.ui_to_s_ksk + } + + pub(in super::super) fn user_index(&self) -> usize { + self.user_id + } + + pub(in super::super) fn auto_keys_share(&self) -> &HashMap { + &self.auto_keys_share + } + + pub(in super::super) fn lwe_ksk_share(&self) -> &M::R { + &self.lwe_ksk_share + } + + pub(in super::super) fn ui_to_s_ksk_zero_encs_for_user_i(&self, user_i: usize) -> &M { + assert!(user_i != self.user_id); + if user_i < self.user_id { + &self.ksk_zero_encs_for_others[user_i] + } else { + &self.ksk_zero_encs_for_others[user_i - 1] + } + } + + pub(in super::super) fn cr_seed(&self) -> &S { + &self.cr_seed + } + + pub(in super::super) fn parameters(&self) -> &P { + &self.parameters + } + } +} + +/// Stores both normal and shoup representation of elements in the container +/// (for ex, a matrix). +/// +/// To access normal representation borrow self as a `self.as_ref()`. To access +/// shoup representation call `self.shoup_repr()` +pub(crate) struct NormalAndShoup(M, M); + +impl NormalAndShoup { + fn new_with_modulus(value: M, modulus: ::Modulus) -> Self { + let value_shoup = M::to_shoup(&value, modulus); + NormalAndShoup(value, value_shoup) + } +} + +impl AsRef for NormalAndShoup { + fn as_ref(&self) -> &M { + &self.0 + } +} + +impl WithShoupRepr for NormalAndShoup { + type M = M; + fn shoup_repr(&self) -> &Self::M { + &self.1 + } +} + +#[cfg(test)] +pub(crate) mod key_size { + use num_traits::{FromPrimitive, PrimInt}; + + use crate::{backend::Modulus, decomposer::NumInfo, SizeInBitsWithLogModulus}; + + use super::*; + + /// Size of the Key in Bits + pub(crate) trait KeySize { + /// Returns size of the key in bits + fn size(&self) -> usize; + } + + impl KeySize + for CommonReferenceSeededInteractiveMultiPartyServerKeyShare, S> + where + M: SizeInBitsWithLogModulus, + M::R: SizeInBitsWithLogModulus, + El: PrimInt + NumInfo + FromPrimitive, + { + fn size(&self) -> usize { + let mut total = 0; + + let log_rlweq = self.parameters().rlwe_q().log_q(); + self.self_leader_rgsws + .iter() + .for_each(|v| total += v.size(log_rlweq)); + self.not_self_leader_rgsws + .iter() + .for_each(|v| total += v.size(log_rlweq)); + self.auto_keys + .values() + .for_each(|v| total += v.size(log_rlweq)); + + let log_lweq = self.parameters().lwe_q().log_q(); + total += self.lwe_ksk.size(log_lweq); + total + } + } + + impl KeySize + for CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare, S> + where + M: SizeInBitsWithLogModulus, + M::R: SizeInBitsWithLogModulus, + El: PrimInt + NumInfo + FromPrimitive, + { + fn size(&self) -> usize { + let mut total = 0; + + let log_rlweq = self.parameters.rlwe_q().log_q(); + self.self_leader_ni_rgsw_cts + .iter() + .for_each(|v| total += v.size(log_rlweq)); + self.not_self_leader_ni_rgsw_cts + .iter() + .for_each(|v| total += v.size(log_rlweq)); + self.ni_rgsw_zero_encs + .iter() + .for_each(|v| total += v.size(log_rlweq)); + total += self.ui_to_s_ksk.size(log_rlweq); + self.ksk_zero_encs_for_others + .iter() + .for_each(|v| total += v.size(log_rlweq)); + self.auto_keys_share + .values() + .for_each(|v| total += v.size(log_rlweq)); + + let log_lweq = self.parameters.lwe_q().log_q(); + total += self.lwe_ksk_share.size(log_lweq); + + total + } + } +} + +pub(super) mod tests { + use itertools::izip; + use num_traits::{FromPrimitive, PrimInt, Zero}; + + use crate::{ + backend::GetModulus, bool::ClientKey, decomposer::NumInfo, lwe::decrypt_lwe, + parameters::CiphertextModulus, utils::TryConvertFrom1, ArithmeticOps, Row, + }; + + use super::SinglePartyClientKey; + + pub(crate) fn ideal_sk_rlwe(cks: &[ClientKey]) -> Vec { + let mut ideal_rlwe_sk = cks[0].sk_rlwe(); + cks.iter().skip(1).for_each(|k| { + let sk_rlwe = k.sk_rlwe(); + izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| { + *a = *a + b; + }); + }); + ideal_rlwe_sk + } + + pub(crate) fn ideal_sk_lwe(cks: &[ClientKey]) -> Vec { + let mut ideal_rlwe_sk = cks[0].sk_lwe(); + cks.iter().skip(1).for_each(|k| { + let sk_rlwe = k.sk_lwe(); + izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| { + *a = *a + b; + }); + }); + ideal_rlwe_sk + } + + pub(crate) fn measure_noise_lwe< + R: Row, + S, + Modop: ArithmeticOps + + GetModulus, Element = R::Element>, + >( + lwe_ct: &R, + m_expected: R::Element, + sk: &[S], + modop: &Modop, + ) -> R::Element + where + R: TryConvertFrom1<[S], CiphertextModulus>, + R::Element: Zero + FromPrimitive + PrimInt + NumInfo, + { + let noisy_m = decrypt_lwe(lwe_ct, &sk, modop); + let noise = modop.sub(&m_expected, &noisy_m); + noise + } + // #[test] + // fn trial() { + // let parameters = I_2P; + // let ck = ClientKey::new(parameters); + // let lwe = ck.sk_lwe(); + // dbg!(lwe); + // } +} diff --git a/src/bool/mod.rs b/src/bool/mod.rs new file mode 100644 index 0000000..9567b09 --- /dev/null +++ b/src/bool/mod.rs @@ -0,0 +1,266 @@ +mod evaluator; +mod keys; +pub(crate) mod parameters; + +#[cfg(feature = "interactive_mp")] +mod mp_api; +#[cfg(feature = "non_interactive_mp")] +mod ni_mp_api; + +#[cfg(feature = "non_interactive_mp")] +pub use ni_mp_api::*; + +#[cfg(feature = "interactive_mp")] +pub use mp_api::*; + +use crate::RowEntity; + +pub type ClientKey = keys::ClientKey<[u8; 32], u64>; +#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))] +pub type FheBool = impl_bool_frontend::FheBool>; + +pub(crate) trait BooleanGates { + type Ciphertext: RowEntity; + type Key; + + fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn or_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xnor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn not_inplace(&self, c: &mut Self::Ciphertext); + + fn and( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nand( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn or( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xnor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn not(&self, c: &Self::Ciphertext) -> Self::Ciphertext; +} + +#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))] +mod impl_bool_frontend { + use crate::MultiPartyDecryptor; + + /// Fhe Bool ciphertext + #[derive(Clone)] + pub struct FheBool { + pub(crate) data: C, + } + + impl FheBool { + pub(crate) fn data(&self) -> &C { + &self.data + } + + pub(crate) fn data_mut(&mut self) -> &mut C { + &mut self.data + } + } + + impl MultiPartyDecryptor> for K + where + K: MultiPartyDecryptor, + { + type DecryptionShare = >::DecryptionShare; + + fn aggregate_decryption_shares( + &self, + c: &FheBool, + shares: &[Self::DecryptionShare], + ) -> bool { + self.aggregate_decryption_shares(&c.data, shares) + } + + fn gen_decryption_share(&self, c: &FheBool) -> Self::DecryptionShare { + self.gen_decryption_share(&c.data) + } + } + + mod ops { + use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not}; + + use crate::{ + utils::{Global, WithLocal}, + BooleanGates, + }; + + use super::super::{BoolEvaluator, RuntimeServerKey}; + + type FheBool = super::super::FheBool; + + impl BitAnd for &FheBool { + type Output = FheBool; + fn bitand(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.and(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitAndAssign for FheBool { + fn bitand_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.and_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl BitOr for &FheBool { + type Output = FheBool; + fn bitor(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.or(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitOrAssign for FheBool { + fn bitor_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.or_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl BitXor for &FheBool { + type Output = FheBool; + fn bitxor(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.xor(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitXorAssign for FheBool { + fn bitxor_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.xor_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl Not for &FheBool { + type Output = FheBool; + fn not(self) -> Self::Output { + BoolEvaluator::with_local(|e| FheBool { + data: e.not(self.data()), + }) + } + } + } +} + +#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))] +mod common_mp_enc_dec { + use itertools::Itertools; + + use super::BoolEvaluator; + use crate::{ + pbs::{sample_extract, PbsInfo}, + utils::WithLocal, + Matrix, RowEntity, SampleExtractor, + }; + + type Mat = Vec>; + + impl SampleExtractor<::R> for Mat { + /// Sample extract coefficient at `index` as a LWE ciphertext from RLWE + /// ciphertext `Self` + fn extract_at(&self, index: usize) -> ::R { + // input is RLWE ciphertext + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + assert!(index < ring_size); + + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + } + + /// Extract first `how_many` coefficients of `Self` as LWE ciphertexts + fn extract_many(&self, how_many: usize) -> Vec<::R> { + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + assert!(how_many <= ring_size); + + (0..how_many) + .map(|index| { + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + }) + .collect_vec() + } + + /// Extracts all coefficients of `Self` as LWE ciphertexts + fn extract_all(&self) -> Vec<::R> { + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + + (0..ring_size) + .map(|index| { + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + }) + .collect_vec() + } + } +} + +#[cfg(test)] +mod print_noise; diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs new file mode 100644 index 0000000..f88a863 --- /dev/null +++ b/src/bool/mp_api.rs @@ -0,0 +1,697 @@ +use std::{cell::RefCell, sync::OnceLock}; + +use crate::{ + backend::{ModularOpsU64, ModulusPowerOf2}, + ntt::NttBackendU64, + random::{DefaultSecureRng, NewWithSeed}, + utils::{Global, WithLocal}, +}; + +use super::{evaluator::InteractiveMultiPartyCrs, keys::*, parameters::*, ClientKey}; + +pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModulusPowerOf2>, + ShoupServerKeyEvaluationDomain>>, +>; + +thread_local! { + static BOOL_EVALUATOR: RefCell> = RefCell::new(None); + +} +static BOOL_SERVER_KEY: OnceLock>>> = OnceLock::new(); + +static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); + +pub enum ParameterSelector { + InteractiveLTE2Party, + InteractiveLTE4Party, + InteractiveLTE8Party, +} + +/// Select Interactive multi-party parameter variant +pub fn set_parameter_set(select: ParameterSelector) { + match select { + ParameterSelector::InteractiveLTE2Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_2P_LB_SR))); + } + ParameterSelector::InteractiveLTE4Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_4P))); + } + ParameterSelector::InteractiveLTE8Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_8P))); + } + } +} + +/// Set application specific interactive multi-party common reference string +pub fn set_common_reference_seed(seed: [u8; 32]) { + assert!( + MULTI_PARTY_CRS + .set(InteractiveMultiPartyCrs { seed: seed }) + .is_ok(), + "Attempted to set MP SEED twice." + ) +} + +/// Generate client key for interactive multi-party protocol +pub fn gen_client_key() -> ClientKey { + BoolEvaluator::with_local(|e| e.client_key()) +} + +/// Generate client's share for collective public key, i.e round 1 share, of the +/// 2 round protocol +pub fn collective_pk_share( + ck: &ClientKey, +) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { + BoolEvaluator::with_local(|e| { + let pk_share = e.multi_party_public_key_share(InteractiveMultiPartyCrs::global(), ck); + pk_share + }) +} + +/// Generate clients share for collective server key, i.e. round 2, of the +/// 2 round protocol +pub fn collective_server_key_share( + ck: &ClientKey, + user_id: usize, + total_users: usize, + pk: &PublicKey>, R, ModOp>, +) -> CommonReferenceSeededInteractiveMultiPartyServerKeyShare< + Vec>, + BoolParameters, + InteractiveMultiPartyCrs<[u8; 32]>, +> { + BoolEvaluator::with_local_mut(|e| { + let server_key_share = e.gen_interactive_multi_party_server_key_share( + user_id, + total_users, + InteractiveMultiPartyCrs::global(), + pk.key(), + ck, + ); + server_key_share + }) +} + +/// Aggregate public key shares from all parties. +/// +/// Public key shares are generated per client in round 1. Aggregation of public +/// key shares marks the end of round 1. +pub fn aggregate_public_key_shares( + shares: &[CommonReferenceSeededCollectivePublicKeyShare< + Vec, + [u8; 32], + BoolParameters, + >], +) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { + PublicKey::from(shares) +} + +/// Aggregate server key shares +pub fn aggregate_server_key_shares( + shares: &[CommonReferenceSeededInteractiveMultiPartyServerKeyShare< + Vec>, + BoolParameters, + InteractiveMultiPartyCrs<[u8; 32]>, + >], +) -> SeededInteractiveMultiPartyServerKey< + Vec>, + InteractiveMultiPartyCrs<[u8; 32]>, + BoolParameters, +> { + BoolEvaluator::with_local(|e| e.aggregate_interactive_multi_party_server_key_shares(shares)) +} + +impl + SeededInteractiveMultiPartyServerKey< + Vec>, + InteractiveMultiPartyCrs<::Seed>, + BoolParameters, + > +{ + /// Sets the server key as a global reference for circuit evaluation + pub fn set_server_key(&self) { + assert!( + BOOL_SERVER_KEY + .set(ShoupServerKeyEvaluationDomain::from( + ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self), + )) + .is_ok(), + "Attempted to set server key twice." + ); + } +} + +// MULTIPARTY CRS // +impl Global for InteractiveMultiPartyCrs<[u8; 32]> { + fn global() -> &'static Self { + MULTI_PARTY_CRS + .get() + .expect("Multi Party Common Reference String not set") + } +} + +// BOOL EVALUATOR // +impl WithLocal for BoolEvaluator { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set"))) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } +} + +pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain>>; +impl Global for RuntimeServerKey { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().expect("Server key not set!") + } +} + +mod impl_enc_dec { + use crate::{ + bool::evaluator::BoolEncoding, + multi_party::{ + multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share, + }, + pbs::{sample_extract, PbsInfo}, + rgsw::public_key_encrypt_rlwe, + utils::TryConvertFrom1, + Encryptor, Matrix, MatrixEntity, MultiPartyDecryptor, RowEntity, + }; + use itertools::Itertools; + use num_traits::{ToPrimitive, Zero}; + + use super::*; + + type Mat = Vec>; + + impl Encryptor<[bool], Vec> for PublicKey { + fn encrypt(&self, m: &[bool]) -> Vec { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + + let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil()) + .to_usize() + .unwrap(); + + // encrypt `m` into ceil(len(m)/N) RLWE ciphertexts + let rlwes = (0..rlwe_count) + .map(|index| { + let mut message = vec![::MatElement::zero(); ring_size]; + m[(index * ring_size)..std::cmp::min(m.len(), (index + 1) * ring_size)] + .iter() + .enumerate() + .for_each(|(i, v)| { + if *v { + message[i] = parameters.rlwe_q().true_el() + } else { + message[i] = parameters.rlwe_q().false_el() + } + }); + + // encrypt message + let mut rlwe_out = + ::zeros(2, parameters.rlwe_n().0); + + public_key_encrypt_rlwe::<_, _, _, _, i32, _>( + &mut rlwe_out, + self.key(), + &message, + e.pbs_info().modop_rlweq(), + e.pbs_info().nttop_rlweq(), + rng, + ); + + rlwe_out + }) + .collect_vec(); + rlwes + }) + }) + } + } + + impl Encryptor::R> for PublicKey { + fn encrypt(&self, m: &bool) -> ::R { + let m = vec![*m]; + let rlwe = &self.encrypt(m.as_slice())[0]; + BoolEvaluator::with_local(|e| { + let mut lwe = ::R::zeros(e.parameters().rlwe_n().0 + 1); + sample_extract(&mut lwe, rlwe, e.pbs_info().modop_rlweq(), 0); + lwe + }) + } + } + + impl MultiPartyDecryptor::R> for K + where + K: InteractiveMultiPartyClientKey, + ::R: + TryConvertFrom1<[K::Element], CiphertextModulus<::MatElement>>, + { + type DecryptionShare = ::MatElement; + + fn gen_decryption_share(&self, c: &::R) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + multi_party_decryption_share( + c, + self.sk_rlwe().as_slice(), + e.pbs_info().modop_rlweq(), + rng, + ) + }) + }) + } + + fn aggregate_decryption_shares( + &self, + c: &::R, + shares: &[Self::DecryptionShare], + ) -> bool { + BoolEvaluator::with_local(|e| { + let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt( + c, + shares, + e.pbs_info().modop_rlweq(), + ); + + e.pbs_info().rlwe_q().decode(noisy_m) + }) + } + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use rand::{thread_rng, Rng, RngCore}; + + use crate::{bool::evaluator::BoolEncoding, Encryptor, MultiPartyDecryptor, SampleExtractor}; + + use super::*; + + #[test] + fn batched_fhe_u8s_extract_works() { + set_parameter_set(ParameterSelector::InteractiveLTE2Party); + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let parties = 2; + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + + // round 1 + let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec(); + + // collective pk + let pk = aggregate_public_key_shares(&pk_shares); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + + let batch_size = parameters.rlwe_n().0 * 3 + 123; + let m = (0..batch_size) + .map(|_| thread_rng().gen::()) + .collect_vec(); + + let seeded_ct = pk.encrypt(m.as_slice()); + + let m_back = (0..batch_size) + .map(|i| { + let ct = seeded_ct.extract_at(i); + cks[0].aggregate_decryption_shares( + &ct, + &cks.iter() + .map(|k| k.gen_decryption_share(&ct)) + .collect_vec(), + ) + }) + .collect_vec(); + + assert_eq!(m, m_back); + } + + mod sp_api { + use num_traits::ToPrimitive; + + use crate::{ + bool::impl_bool_frontend::FheBool, pbs::PbsInfo, rgsw::seeded_secret_key_encrypt_rlwe, + Decryptor, + }; + + use super::*; + + pub(crate) fn set_single_party_parameter_sets(parameter: BoolParameters) { + BOOL_EVALUATOR.with_borrow_mut(|e| *e = Some(BoolEvaluator::new(parameter))); + } + + // SERVER KEY EVAL (/SHOUP) DOMAIN // + impl SeededSinglePartyServerKey>, BoolParameters, [u8; 32]> { + pub fn set_server_key(&self) { + assert!( + BOOL_SERVER_KEY + .set( + ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from( + self + ),) + ) + .is_ok(), + "Attempted to set server key twice." + ); + } + } + + pub(crate) fn gen_keys() -> ( + ClientKey, + SeededSinglePartyServerKey>, BoolParameters, [u8; 32]>, + ) { + super::BoolEvaluator::with_local_mut(|e| { + let ck = e.client_key(); + let sk = e.single_party_server_key(&ck); + + (ck, sk) + }) + } + + impl> Encryptor> for K { + fn encrypt(&self, m: &bool) -> Vec { + BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self)) + } + } + + impl> Decryptor> for K { + fn decrypt(&self, c: &Vec) -> bool { + BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) + } + } + + impl, C> Encryptor> for K + where + K: Encryptor, + { + fn encrypt(&self, m: &bool) -> FheBool { + FheBool { + data: self.encrypt(m), + } + } + } + + impl, C> Decryptor> for K + where + K: Decryptor, + { + fn decrypt(&self, c: &FheBool) -> bool { + self.decrypt(c.data()) + } + } + + impl Encryptor<[bool], (Vec>, [u8; 32])> for K + where + K: SinglePartyClientKey, + { + fn encrypt(&self, m: &[bool]) -> (Vec>, [u8; 32]) { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + + let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil()) + .to_usize() + .unwrap(); + + let mut seed = ::Seed::default(); + rng.fill_bytes(&mut seed); + let mut prng = DefaultSecureRng::new_seeded(seed); + + let sk_u = self.sk_rlwe(); + + // encrypt `m` into ceil(len(m)/N) RLWE ciphertexts + let rlwes = (0..rlwe_count) + .map(|index| { + let mut message = vec![0; ring_size]; + m[(index * ring_size) + ..std::cmp::min(m.len(), (index + 1) * ring_size)] + .iter() + .enumerate() + .for_each(|(i, v)| { + if *v { + message[i] = parameters.rlwe_q().true_el() + } else { + message[i] = parameters.rlwe_q().false_el() + } + }); + + // encrypt message + let mut rlwe_out = vec![0u64; parameters.rlwe_n().0]; + seeded_secret_key_encrypt_rlwe( + &message, + &mut rlwe_out, + &sk_u, + e.pbs_info().modop_rlweq(), + e.pbs_info().nttop_rlweq(), + &mut prng, + rng, + ); + + rlwe_out + }) + .collect_vec(); + + (rlwes, seed) + }) + }) + } + } + + #[test] + #[cfg(feature = "interactive_mp")] + fn all_uint8_apis() { + use num_traits::Euclid; + + use crate::{div_zero_error_flag, FheBool}; + + set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); + + let (ck, sk) = gen_keys(); + sk.set_server_key(); + + for i in 0..=255 { + for j in 0..=255 { + let m0 = i; + let m1 = j; + let c0 = ck.encrypt(&m0); + let c1 = ck.encrypt(&m1); + + assert!(ck.decrypt(&c0) == m0); + assert!(ck.decrypt(&c1) == m1); + + // Arithmetic + { + { + // Add + let c_add = &c0 + &c1; + let m0_plus_m1 = ck.decrypt(&c_add); + assert_eq!( + m0_plus_m1, + m0.wrapping_add(m1), + "Expected {} but got {m0_plus_m1} for + {i}+{j}", + m0.wrapping_add(m1) + ); + } + { + // Sub + let c_sub = &c0 - &c1; + let m0_sub_m1 = ck.decrypt(&c_sub); + assert_eq!( + m0_sub_m1, + m0.wrapping_sub(m1), + "Expected {} but got {m0_sub_m1} for + {i}-{j}", + m0.wrapping_sub(m1) + ); + } + + { + // Mul + let c_m0m1 = &c0 * &c1; + let m0m1 = ck.decrypt(&c_m0m1); + assert_eq!( + m0m1, + m0.wrapping_mul(m1), + "Expected {} but got {m0m1} for {i}x{j}", + m0.wrapping_mul(m1) + ); + } + + // Div & Rem + { + let (c_quotient, c_rem) = c0.div_rem(&c1); + let m_quotient = ck.decrypt(&c_quotient); + let m_remainder = ck.decrypt(&c_rem); + if j != 0 { + let (q, r) = i.div_rem_euclid(&j); + assert_eq!( + m_quotient, q, + "Expected {} but got {m_quotient} for + {i}/{j}", + q + ); + assert_eq!( + m_remainder, r, + "Expected {} but got {m_remainder} for + {i}%{j}", + r + ); + } else { + assert_eq!( + m_quotient, 255, + "Expected 255 but got {m_quotient}. Case + div by zero" + ); + assert_eq!( + m_remainder, i, + "Expected {i} but got {m_remainder}. Case + div by zero" + ); + + let div_by_zero = ck.decrypt(&div_zero_error_flag().unwrap()); + assert_eq!( + div_by_zero, true, + "Expected true but got {div_by_zero}" + ); + } + } + } + + // // Comparisons + { + { + let c_eq = c0.eq(&c1); + let is_eq = ck.decrypt(&c_eq); + assert_eq!( + is_eq, + i == j, + "Expected {} but got {is_eq} for {i}=={j}", + i == j + ); + } + + { + let c_gt = c0.gt(&c1); + let is_gt = ck.decrypt(&c_gt); + assert_eq!( + is_gt, + i > j, + "Expected {} but got {is_gt} for {i}>{j}", + i > j + ); + } + + { + let c_lt = c0.lt(&c1); + let is_lt = ck.decrypt(&c_lt); + assert_eq!( + is_lt, + i < j, + "Expected {} but got {is_lt} for {i}<{j}", + i < j + ); + } + + { + let c_ge = c0.ge(&c1); + let is_ge = ck.decrypt(&c_ge); + assert_eq!( + is_ge, + i >= j, + "Expected {} but got {is_ge} for {i}>={j}", + i >= j + ); + } + + { + let c_le = c0.le(&c1); + let is_le = ck.decrypt(&c_le); + assert_eq!( + is_le, + i <= j, + "Expected {} but got {is_le} for {i}<={j}", + i <= j + ); + } + } + + // mux + { + let selector = thread_rng().gen_bool(0.5); + let selector_enc: FheBool = ck.encrypt(&selector); + let mux_out = ck.decrypt(&c0.mux(&c1, &selector_enc)); + let want_mux_out = if selector { m0 } else { m1 }; + assert_eq!(mux_out, want_mux_out); + } + } + } + } + + #[test] + #[cfg(feature = "interactive_mp")] + fn all_bool_apis() { + use crate::FheBool; + + set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); + + let (ck, sk) = gen_keys(); + sk.set_server_key(); + + for _ in 0..100 { + let a = thread_rng().gen_bool(0.5); + let b = thread_rng().gen_bool(0.5); + + let c_a: FheBool = ck.encrypt(&a); + let c_b: FheBool = ck.encrypt(&b); + + let c_out = &c_a & &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a & b, "Expected {} but got {out}", a & b); + + let c_out = &c_a | &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a | b, "Expected {} but got {out}", a | b); + + let c_out = &c_a ^ &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a ^ b, "Expected {} but got {out}", a ^ b); + + let c_out = !(&c_a); + let out = ck.decrypt(&c_out); + assert_eq!(out, !a, "Expected {} but got {out}", !a); + } + } + } +} diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs new file mode 100644 index 0000000..7efaa6f --- /dev/null +++ b/src/bool/ni_mp_api.rs @@ -0,0 +1,459 @@ +use std::{cell::RefCell, sync::OnceLock}; + +use crate::{ + backend::ModulusPowerOf2, + bool::parameters::ParameterVariant, + random::DefaultSecureRng, + utils::{Global, WithLocal}, + ModularOpsU64, NttBackendU64, +}; + +use super::{ + evaluator::NonInteractiveMultiPartyCrs, + keys::{ + CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare, + NonInteractiveServerKeyEvaluationDomain, SeededNonInteractiveMultiPartyServerKey, + ShoupNonInteractiveServerKeyEvaluationDomain, + }, + parameters::{BoolParameters, CiphertextModulus, NI_2P, NI_4P_HB_FR, NI_8P}, + ClientKey, +}; + +pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModulusPowerOf2>, + ShoupNonInteractiveServerKeyEvaluationDomain>>, +>; + +thread_local! { + static BOOL_EVALUATOR: RefCell> = RefCell::new(None); + +} +static BOOL_SERVER_KEY: OnceLock>>> = + OnceLock::new(); + +static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); + +pub enum ParameterSelector { + NonInteractiveLTE2Party, + NonInteractiveLTE4Party, + NonInteractiveLTE8Party, +} + +pub fn set_parameter_set(select: ParameterSelector) { + match select { + ParameterSelector::NonInteractiveLTE2Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_2P))); + } + ParameterSelector::NonInteractiveLTE4Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_4P_HB_FR))); + } + ParameterSelector::NonInteractiveLTE8Party => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_8P))); + } + } +} + +pub fn set_common_reference_seed(seed: [u8; 32]) { + BoolEvaluator::with_local(|e| { + assert_eq!( + e.parameters().variant(), + &ParameterVariant::NonInteractiveMultiParty, + "Set parameters do not support Non interactive multi-party" + ); + }); + + assert!( + MULTI_PARTY_CRS + .set(NonInteractiveMultiPartyCrs { seed: seed }) + .is_ok(), + "Attempted to set MP SEED twice." + ) +} + +pub fn gen_client_key() -> ClientKey { + BoolEvaluator::with_local(|e| e.client_key()) +} + +pub fn gen_server_key_share( + user_id: usize, + total_users: usize, + client_key: &ClientKey, +) -> CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare< + Vec>, + BoolParameters, + NonInteractiveMultiPartyCrs<[u8; 32]>, +> { + BoolEvaluator::with_local(|e| { + let cr_seed = NonInteractiveMultiPartyCrs::global(); + e.gen_non_interactive_multi_party_key_share(cr_seed, user_id, total_users, client_key) + }) +} + +pub fn aggregate_server_key_shares( + shares: &[CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare< + Vec>, + BoolParameters, + NonInteractiveMultiPartyCrs<[u8; 32]>, + >], +) -> SeededNonInteractiveMultiPartyServerKey< + Vec>, + NonInteractiveMultiPartyCrs<[u8; 32]>, + BoolParameters, +> { + BoolEvaluator::with_local(|e| { + let cr_seed = NonInteractiveMultiPartyCrs::global(); + e.aggregate_non_interactive_multi_party_server_key_shares(cr_seed, shares) + }) +} + +impl + SeededNonInteractiveMultiPartyServerKey< + Vec>, + NonInteractiveMultiPartyCrs<[u8; 32]>, + BoolParameters, + > +{ + pub fn set_server_key(&self) { + let eval_key = NonInteractiveServerKeyEvaluationDomain::< + _, + BoolParameters, + DefaultSecureRng, + NttBackendU64, + >::from(self); + assert!( + BOOL_SERVER_KEY + .set(ShoupNonInteractiveServerKeyEvaluationDomain::from(eval_key)) + .is_ok(), + "Attempted to set server key twice!" + ); + } +} + +impl Global for NonInteractiveMultiPartyCrs<[u8; 32]> { + fn global() -> &'static Self { + MULTI_PARTY_CRS + .get() + .expect("Non-interactive multi-party common reference string not set") + } +} + +// BOOL EVALUATOR // +impl WithLocal for BoolEvaluator { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set"))) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } +} + +pub(crate) type RuntimeServerKey = ShoupNonInteractiveServerKeyEvaluationDomain>>; +impl Global for RuntimeServerKey { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().expect("Server key not set!") + } +} + +/// Batch of bool ciphertexts stored as vector of RLWE ciphertext under user j's +/// RLWE secret `u_j` +/// +/// To use the bool ciphertexts in multi-party protocol first key switch the +/// ciphertexts from u_j to ideal RLWE secret `s` with +/// `self.key_switch(user_id)` where `user_id` is user j's id. Key switch +/// returns `BatchedFheBools` that stored key vector of key switched RLWE +/// ciphertext. +pub(super) struct NonInteractiveBatchedFheBools { + data: Vec, +} + +/// Batch of Bool cipphertexts stored as vector of RLWE ciphertexts under the +/// ideal RLWE secret key `s` of the protocol +/// +/// Bool ciphertext at `index` can be extracted from the coefficient at `index % +/// N` of `index / N`th RLWE ciphertext. +/// +/// To extract bool ciphertext at `index` as LWE ciphertext use +/// `self.extract(index)` +pub(super) struct BatchedFheBools { + pub(in super::super) data: Vec, +} + +/// Non interactive multi-party specfic encryptor decryptor routines +mod impl_enc_dec { + use crate::{ + bool::{evaluator::BoolEncoding, keys::NonInteractiveMultiPartyClientKey}, + multi_party::{ + multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share, + }, + pbs::{sample_extract, PbsInfo, WithShoupRepr}, + random::{NewWithSeed, RandomFillUniformInModulus}, + rgsw::{rlwe_key_switch, seeded_secret_key_encrypt_rlwe}, + utils::TryConvertFrom1, + Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, + RowEntity, RowMut, + }; + use itertools::Itertools; + use num_traits::{ToPrimitive, Zero}; + + use super::*; + + type Mat = Vec>; + + // Implement `extract` to extract Bool LWE ciphertext at `index` from + // `BatchedFheBools` + impl> BatchedFheBools + where + C::R: RowEntity + RowMut, + { + pub(crate) fn extract(&self, index: usize) -> C::R { + BoolEvaluator::with_local(|e| { + let ring_size = e.parameters().rlwe_n().0; + let ct_index = index / ring_size; + let coeff_index = index % ring_size; + let mut lwe_out = C::R::zeros(e.parameters().rlwe_n().0 + 1); + sample_extract( + &mut lwe_out, + &self.data[ct_index], + e.pbs_info().modop_rlweq(), + coeff_index, + ); + lwe_out + }) + } + } + + impl> From<&(Vec, [u8; 32])> + for NonInteractiveBatchedFheBools + where + ::R: RowMut, + { + /// Derive `NonInteractiveBatchedFheBools` from a vector seeded RLWE + /// ciphertexts (Vec, Seed) + /// + /// Unseed the RLWE ciphertexts and store them as vector RLWE + /// ciphertexts in `NonInteractiveBatchedFheBools` + fn from(value: &(Vec, [u8; 32])) -> Self { + BoolEvaluator::with_local(|e| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + let rlwe_q = parameters.rlwe_q(); + + let mut prng = DefaultSecureRng::new_seeded(value.1); + let rlwes = value + .0 + .iter() + .map(|partb| { + let mut rlwe = M::zeros(2, ring_size); + + // sample A + RandomFillUniformInModulus::random_fill( + &mut prng, + rlwe_q, + rlwe.get_row_mut(0), + ); + + // Copy over B + rlwe.get_row_mut(1).copy_from_slice(partb.as_ref()); + + rlwe + }) + .collect_vec(); + Self { data: rlwes } + }) + } + } + + impl Encryptor<[bool], NonInteractiveBatchedFheBools> for K + where + K: Encryptor<[bool], (Mat, [u8; 32])>, + { + /// Encrypt a vector bool of arbitrary length as vector of unseeded RLWE + /// ciphertexts in `NonInteractiveBatchedFheBools` + fn encrypt(&self, m: &[bool]) -> NonInteractiveBatchedFheBools { + NonInteractiveBatchedFheBools::from(&K::encrypt(&self, m)) + } + } + + impl Encryptor<[bool], (Vec<::R>, [u8; 32])> for K + where + K: NonInteractiveMultiPartyClientKey, + ::R: + TryConvertFrom1<[K::Element], CiphertextModulus<::MatElement>>, + { + /// Encrypt a vector of bool of arbitrary length as vector of seeded + /// RLWE ciphertexts and returns (Vec, Seed) + fn encrypt(&self, m: &[bool]) -> (Mat, [u8; 32]) { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + + let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil()) + .to_usize() + .unwrap(); + + let mut seed = ::Seed::default(); + rng.fill_bytes(&mut seed); + let mut prng = DefaultSecureRng::new_seeded(seed); + + let sk_u = self.sk_u_rlwe(); + + // encrypt `m` into ceil(len(m)/N) RLWE ciphertexts + let rlwes = (0..rlwe_count) + .map(|index| { + let mut message = vec![::MatElement::zero(); ring_size]; + m[(index * ring_size)..std::cmp::min(m.len(), (index + 1) * ring_size)] + .iter() + .enumerate() + .for_each(|(i, v)| { + if *v { + message[i] = parameters.rlwe_q().true_el() + } else { + message[i] = parameters.rlwe_q().false_el() + } + }); + + // encrypt message + let mut rlwe_out = + <::R as RowEntity>::zeros(parameters.rlwe_n().0); + + seeded_secret_key_encrypt_rlwe( + &message, + &mut rlwe_out, + &sk_u, + e.pbs_info().modop_rlweq(), + e.pbs_info().nttop_rlweq(), + &mut prng, + rng, + ); + + rlwe_out + }) + .collect_vec(); + + (rlwes, seed) + }) + }) + } + } + + impl MultiPartyDecryptor::R> for K + where + K: NonInteractiveMultiPartyClientKey, + ::R: + TryConvertFrom1<[K::Element], CiphertextModulus<::MatElement>>, + { + type DecryptionShare = ::MatElement; + + fn gen_decryption_share(&self, c: &::R) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + multi_party_decryption_share( + c, + self.sk_rlwe().as_slice(), + e.pbs_info().modop_rlweq(), + rng, + ) + }) + }) + } + + fn aggregate_decryption_shares( + &self, + c: &::R, + shares: &[Self::DecryptionShare], + ) -> bool { + BoolEvaluator::with_local(|e| { + let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt( + c, + shares, + e.pbs_info().modop_rlweq(), + ); + + e.pbs_info().rlwe_q().decode(noisy_m) + }) + } + } + + impl KeySwitchWithId for Mat { + /// Key switch RLWE ciphertext `Self` from user j's RLWE secret u_j + /// to ideal RLWE secret `s` of non-interactive multi-party protocol. + /// + /// - user_id: user j's user_id in the protocol + fn key_switch(&self, user_id: usize) -> Mat { + BoolEvaluator::with_local(|e| { + assert!(self.dimension() == (2, e.parameters().rlwe_n().0)); + let server_key = BOOL_SERVER_KEY.get().unwrap(); + let ksk = server_key.ui_to_s_ksk(user_id); + let decomposer = e.ni_ui_to_s_ks_decomposer().as_ref().unwrap(); + + // perform key switch + rlwe_key_switch( + self, + ksk.as_ref(), + ksk.shoup_repr(), + decomposer, + e.pbs_info().nttop_rlweq(), + e.pbs_info().modop_rlweq(), + ) + }) + } + } + + impl KeySwitchWithId> for NonInteractiveBatchedFheBools + where + C: KeySwitchWithId, + { + /// Key switch `Self`'s vector of RLWE ciphertexts from user j's RLWE + /// secret u_j to ideal RLWE secret `s` of non-interactive + /// multi-party protocol. + /// + /// Returns vector of key switched RLWE ciphertext as `BatchedFheBools` + /// which can then be used to extract individual Bool LWE ciphertexts. + /// + /// - user_id: user j's user_id in the protocol + fn key_switch(&self, user_id: usize) -> BatchedFheBools { + let data = self + .data + .iter() + .map(|c| c.key_switch(user_id)) + .collect_vec(); + BatchedFheBools { data } + } + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use rand::{thread_rng, RngCore}; + + use crate::{ + backend::Modulus, + bool::{ + keys::tests::{ideal_sk_rlwe, measure_noise_lwe}, + BooleanGates, + }, + utils::tests::Stats, + Encoder, Encryptor, KeySwitchWithId, MultiPartyDecryptor, + }; + + use super::*; +} diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs new file mode 100644 index 0000000..7d7beab --- /dev/null +++ b/src/bool/parameters.rs @@ -0,0 +1,738 @@ +use num_traits::{ConstZero, FromPrimitive, PrimInt}; + +use crate::{ + backend::Modulus, + decomposer::{Decomposer, NumInfo}, + utils::log2, +}; + +pub(crate) trait DoubleDecomposerCount { + type Count; + fn a(&self) -> Self::Count; + fn b(&self) -> Self::Count; +} + +pub(crate) trait DoubleDecomposerParams { + type Base; + type Count; + + fn decomposition_base(&self) -> Self::Base; + fn decomposition_count_a(&self) -> Self::Count; + fn decomposition_count_b(&self) -> Self::Count; +} + +pub(crate) trait SingleDecomposerParams { + type Base; + type Count; + + // fn new(base: Self::Base, count: Self::Count) -> Self; + fn decomposition_base(&self) -> Self::Base; + fn decomposition_count(&self) -> Self::Count; +} + +impl DoubleDecomposerParams + for ( + DecompostionLogBase, + // Assume (Decomposition count for A, Decomposition count for B) + (DecompositionCount, DecompositionCount), + ) +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + // fn new( + // base: DecompostionLogBase, + // count_a: DecompositionCount, + // count_b: DecompositionCount, + // ) -> Self { + // (base, (count_a, count_b)) + // } + + fn decomposition_base(&self) -> Self::Base { + self.0 + } + + fn decomposition_count_a(&self) -> Self::Count { + self.1 .0 + } + + fn decomposition_count_b(&self) -> Self::Count { + self.1 .1 + } +} + +impl SingleDecomposerParams for (DecompostionLogBase, DecompositionCount) { + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + // fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self { + // (base, count) + // } + + fn decomposition_base(&self) -> Self::Base { + self.0 + } + + fn decomposition_count(&self) -> Self::Count { + self.1 + } +} + +#[derive(Clone, PartialEq, Debug)] +pub(crate) enum SecretKeyDistribution { + /// Elements of secret key are sample from Gaussian distribitution with + /// \sigma = 3.19 and \mu = 0.0 + ErrorDistribution, + /// Elements of secret key are chosen from the set {1,0,-1} with hamming + /// weight `floor(N/2)` where `N` is the secret dimension. + TernaryDistribution, +} + +#[derive(Clone, PartialEq, Debug)] +pub(crate) enum ParameterVariant { + SingleParty, + InteractiveMultiParty, + NonInteractiveMultiParty, +} +#[derive(Clone, PartialEq)] +pub struct BoolParameters { + /// RLWE secret key distribution + rlwe_secret_key_dist: SecretKeyDistribution, + /// LWE secret key distribtuion + lwe_secret_key_dist: SecretKeyDistribution, + /// RLWE ciphertext modulus Q + rlwe_q: CiphertextModulus, + /// LWE ciphertext modulus q (usually referred to as Q_{ks}) + lwe_q: CiphertextModulus, + /// Blind rotation modulus. It is the modulus to which we switch before + /// blind rotation. + /// + /// Since blind rotation decrypts LWE ciphertext in the exponent of a ring + /// polynomial, which is a ring mod 2N, blind rotation modulus is + /// always <= 2N. + br_q: usize, + /// Ring dimension `N` for 2N^{th} cyclotomic polynomial ring + rlwe_n: PolynomialSize, + /// LWE dimension `n` + lwe_n: LweDimension, + /// LWE key switch decompositon params + lwe_decomposer_params: (DecompostionLogBase, DecompositionCount), + /// Decompostion parameters for RLWE x RGSW. + /// + /// We restrict decomposition for RLWE'(-sm) and RLWE'(m) to have same base + /// but can have different decomposition count. We refer to this + /// DoubleDecomposer / RlweDecomposer + /// + /// Decomposition count `d_a` (i.e. for SignedDecompose(RLWE_A(m)) x + /// RLWE'(-sm)) and `d_b` (i.e. for SignedDecompose(RLWE_B(m)) x RLWE'(m)) + /// are always stored as `(d_a, d_b)` + rlrg_decomposer_params: ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ), + /// Decomposition parameters for RLWE automorphism + auto_decomposer_params: (DecompostionLogBase, DecompositionCount), + /// Decomposition parameters for RGSW0 x RGSW1 + /// + /// `0` and `1` indicate that RGSW0 and RGSW1 may not use same decomposition + /// parameters. + /// + /// In RGSW0 x RGSW1, decomposition parameters for RGSW1 are required. + /// Hence, the parameters we store are decomposition parameters of RGSW1. + /// + /// Like RLWE x RGSW decomposition parameters (1) we restrict to same base + /// but can have different decomposition counts `d_a` and `d_b` and (2) + /// decomposition count `d_a` and `d_b` are always stored as `(d_a, d_b)` + /// + /// RGSW0 x RGSW1 are optional because they only necessary to be supplied in + /// multi-party setting. + rgrg_decomposer_params: Option<( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + )>, + /// Decomposition parameters for non-interactive key switching from u_j to + /// s, hwere u_j is RLWE secret `u` of party `j` and `s` is the ideal RLWE + /// secret key. + /// + /// Decomposition parameters for non-interactive key switching are optional + /// and must be supplied only for non-interactive multi-party + non_interactive_ui_to_s_key_switch_decomposer: + Option<(DecompostionLogBase, DecompositionCount)>, + /// Group generator for Z^*_{br_q} + g: usize, + /// Window size parameter for LMKC++ blind rotation + w: usize, + /// Parameter variant + variant: ParameterVariant, +} + +impl BoolParameters { + pub(crate) fn rlwe_secret_key_dist(&self) -> &SecretKeyDistribution { + &self.rlwe_secret_key_dist + } + + pub(crate) fn lwe_secret_key_dist(&self) -> &SecretKeyDistribution { + &self.lwe_secret_key_dist + } + + pub(crate) fn rlwe_q(&self) -> &CiphertextModulus { + &self.rlwe_q + } + + pub(crate) fn lwe_q(&self) -> &CiphertextModulus { + &self.lwe_q + } + + pub(crate) fn br_q(&self) -> &usize { + &self.br_q + } + + pub(crate) fn rlwe_n(&self) -> &PolynomialSize { + &self.rlwe_n + } + + pub(crate) fn lwe_n(&self) -> &LweDimension { + &self.lwe_n + } + + pub(crate) fn g(&self) -> usize { + self.g + } + + pub(crate) fn w(&self) -> usize { + self.w + } + + pub(crate) fn rlwe_by_rgsw_decomposition_params( + &self, + ) -> &( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ) { + &self.rlrg_decomposer_params + } + + pub(crate) fn rgsw_by_rgsw_decomposition_params( + &self, + ) -> ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ) { + self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSWxRGSW", + self.variant + )) + } + + pub(crate) fn rlwe_rgsw_decomposition_base(&self) -> DecompostionLogBase { + self.rlrg_decomposer_params.0 + } + + pub(crate) fn rlwe_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) { + self.rlrg_decomposer_params.1 + } + + pub(crate) fn rgsw_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) { + let params = self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSW x RGSW", + self.variant + )); + params.1 + } + + pub(crate) fn auto_decomposition_param(&self) -> &(DecompostionLogBase, DecompositionCount) { + &self.auto_decomposer_params + } + + pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase { + self.auto_decomposer_params.decomposition_base() + } + + pub(crate) fn auto_decomposition_count(&self) -> DecompositionCount { + self.auto_decomposer_params.decomposition_count() + } + + pub(crate) fn lwe_decomposition_base(&self) -> DecompostionLogBase { + self.lwe_decomposer_params.decomposition_base() + } + + pub(crate) fn lwe_decomposition_count(&self) -> DecompositionCount { + self.lwe_decomposer_params.decomposition_count() + } + + pub(crate) fn non_interactive_ui_to_s_key_switch_decomposition_count( + &self, + ) -> DecompositionCount { + let params = self + .non_interactive_ui_to_s_key_switch_decomposer + .expect(&format!( + "Parameter variant {:?} does not support non-interactive", + self.variant + )); + params.decomposition_count() + } + + pub(crate) fn rgsw_rgsw_decomposer>(&self) -> (D, D) + where + El: Copy, + { + let params = self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSW x RGSW", + self.variant + )); + ( + // A + D::new( + self.rlwe_q.0, + params.decomposition_base().0, + params.decomposition_count_a().0, + ), + // B + D::new( + self.rlwe_q.0, + params.decomposition_base().0, + params.decomposition_count_b().0, + ), + ) + } + + pub(crate) fn auto_decomposer>(&self) -> D + where + El: Copy, + { + D::new( + self.rlwe_q.0, + self.auto_decomposer_params.decomposition_base().0, + self.auto_decomposer_params.decomposition_count().0, + ) + } + + pub(crate) fn lwe_decomposer>(&self) -> D + where + El: Copy, + { + D::new( + self.lwe_q.0, + self.lwe_decomposer_params.decomposition_base().0, + self.lwe_decomposer_params.decomposition_count().0, + ) + } + + pub(crate) fn rlwe_rgsw_decomposer>(&self) -> (D, D) + where + El: Copy, + { + ( + // A + D::new( + self.rlwe_q.0, + self.rlrg_decomposer_params.decomposition_base().0, + self.rlrg_decomposer_params.decomposition_count_a().0, + ), + // B + D::new( + self.rlwe_q.0, + self.rlrg_decomposer_params.decomposition_base().0, + self.rlrg_decomposer_params.decomposition_count_b().0, + ), + ) + } + + pub(crate) fn non_interactive_ui_to_s_key_switch_decomposer>( + &self, + ) -> D + where + El: Copy, + { + let params = self + .non_interactive_ui_to_s_key_switch_decomposer + .expect(&format!( + "Parameter variant {:?} does not support non-interactive", + self.variant + )); + D::new( + self.rlwe_q.0, + params.decomposition_base().0, + params.decomposition_count().0, + ) + } + + /// Returns dlogs of `g` for which auto keys are required as + /// per the parameter. Given that autos are required for [-g, g, g^2, ..., + /// g^w] function returns the following [0, 1, 2, ..., w] where `w` is + /// the window size. Note that although g^0 = 1, we use 0 for -g. + pub(crate) fn auto_element_dlogs(&self) -> Vec { + let mut els = vec![0]; + (1..self.w + 1).into_iter().for_each(|e| { + els.push(e); + }); + els + } + + pub(crate) fn variant(&self) -> &ParameterVariant { + &self.variant + } +} + +#[derive(Clone, Copy, PartialEq)] +pub struct DecompostionLogBase(pub(crate) usize); +impl AsRef for DecompostionLogBase { + fn as_ref(&self) -> &usize { + &self.0 + } +} +#[derive(Clone, Copy, PartialEq)] +pub struct DecompositionCount(pub(crate) usize); +impl AsRef for DecompositionCount { + fn as_ref(&self) -> &usize { + &self.0 + } +} + +#[derive(Clone, Copy, PartialEq)] +pub(crate) struct LweDimension(pub(crate) usize); +#[derive(Clone, Copy, PartialEq)] +pub(crate) struct PolynomialSize(pub(crate) usize); +#[derive(Clone, Copy, PartialEq, Debug)] + +/// T equals modulus when modulus is non-native. Otherwise T equals 0. bool is +/// true when modulus is native, false otherwise. +pub struct CiphertextModulus(T, bool); + +impl CiphertextModulus { + const fn new_native() -> Self { + // T::zero is stored only for convenience. It has no use when modulus + // is native. That is, either u128,u64,u32,u16 + Self(T::ZERO, true) + } + + const fn new_non_native(q: T) -> Self { + Self(q, false) + } +} + +impl CiphertextModulus +where + T: PrimInt + NumInfo, +{ + fn _bits() -> usize { + T::BITS as usize + } + + fn _native(&self) -> bool { + self.1 + } + + fn _half_q(&self) -> T { + if self._native() { + T::one() << (Self::_bits() - 1) + } else { + self.0 >> 1 + } + } + + fn _q(&self) -> Option { + if self._native() { + None + } else { + Some(self.0) + } + } +} + +impl Modulus for CiphertextModulus +where + T: PrimInt + FromPrimitive + NumInfo, +{ + type Element = T; + fn is_native(&self) -> bool { + self._native() + } + fn largest_unsigned_value(&self) -> Self::Element { + if self._native() { + T::max_value() + } else { + self.0 - T::one() + } + } + fn neg_one(&self) -> Self::Element { + if self._native() { + T::max_value() + } else { + self.0 - T::one() + } + } + // fn signed_max(&self) -> Self::Element {} + // fn signed_min(&self) -> Self::Element {} + fn smallest_unsigned_value(&self) -> Self::Element { + T::zero() + } + + fn map_element_to_i64(&self, v: &Self::Element) -> i64 { + assert!(*v <= self.largest_unsigned_value()); + if *v > self._half_q() { + -((self.largest_unsigned_value() - *v) + T::one()) + .to_i64() + .unwrap() + } else { + v.to_i64().unwrap() + } + } + + fn map_element_from_f64(&self, v: f64) -> Self::Element { + let v = v.round(); + + let v_el = T::from_f64(v.abs()).unwrap(); + assert!(v_el <= self.largest_unsigned_value()); + + if v < 0.0 { + self.largest_unsigned_value() - v_el + T::one() + } else { + v_el + } + } + + fn map_element_from_i64(&self, v: i64) -> Self::Element { + let v_el = T::from_i64(v.abs()).unwrap(); + assert!(v_el <= self.largest_unsigned_value()); + if v < 0 { + self.largest_unsigned_value() - v_el + T::one() + } else { + v_el + } + } + + fn q(&self) -> Option { + self._q() + } + + fn q_as_f64(&self) -> Option { + if self._native() { + Some(T::max_value().to_f64().unwrap() + 1.0) + } else { + self.0.to_f64() + } + } + + fn log_q(&self) -> usize { + if self.is_native() { + Self::_bits() + } else { + log2(&self.q().unwrap()) + } + } +} + +pub(crate) const I_2P_LB_SR: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 15), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(580), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(12)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(7), + (DecompositionCount(6), DecompositionCount(5)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: None, + g: 5, + w: 10, + variant: ParameterVariant::InteractiveMultiParty, +}; + +pub(crate) const I_4P: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(620), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(6), + (DecompositionCount(7), DecompositionCount(6)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: None, + g: 5, + w: 10, + variant: ParameterVariant::InteractiveMultiParty, +}; + +pub(crate) const I_8P: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 17), + br_q: 1 << 12, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(660), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(14)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(5), + (DecompositionCount(9), DecompositionCount(8)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: None, + g: 5, + w: 10, + variant: ParameterVariant::InteractiveMultiParty, +}; + +pub(crate) const NI_2P: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::ErrorDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 12, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(520), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(4), + (DecompositionCount(10), DecompositionCount(9)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(50), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; + +pub(crate) const NI_4P_HB_FR: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(620), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(3), + (DecompositionCount(13), DecompositionCount(12)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(50), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; + +pub(crate) const NI_4P_LB_SR: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 12, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(620), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(4), + (DecompositionCount(10), DecompositionCount(9)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(50), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; + +pub(crate) const NI_8P: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 17), + br_q: 1 << 12, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(660), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(14)), + rlrg_decomposer_params: ( + DecompostionLogBase(17), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(2), + (DecompositionCount(20), DecompositionCount(18)), + )), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(50), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; + +#[cfg(test)] +pub(crate) const SP_TEST_BOOL_PARAMS: BoolParameters = BoolParameters:: { + rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution, + lwe_secret_key_dist: SecretKeyDistribution::ErrorDistribution, + rlwe_q: CiphertextModulus::new_non_native(268369921u64), + lwe_q: CiphertextModulus::new_non_native(1 << 16), + br_q: 1 << 9, + rlwe_n: PolynomialSize(1 << 9), + lwe_n: LweDimension(100), + lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(4)), + rlrg_decomposer_params: ( + DecompostionLogBase(7), + (DecompositionCount(4), DecompositionCount(4)), + ), + rgrg_decomposer_params: None, + auto_decomposer_params: (DecompostionLogBase(7), DecompositionCount(4)), + non_interactive_ui_to_s_key_switch_decomposer: None, + g: 5, + w: 5, + variant: ParameterVariant::SingleParty, +}; + +// #[cfg(test)] +// mod tests { + +// #[test] +// fn find_prime() { +// let bits = 60; +// let ring_size = 1 << 11; +// let prime = crate::utils::generate_prime(bits, ring_size * 2, 1 << +// bits).unwrap(); dbg!(prime); +// } +// } diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs new file mode 100644 index 0000000..9b21ab8 --- /dev/null +++ b/src/bool/print_noise.rs @@ -0,0 +1,1020 @@ +use std::{fmt::Debug, iter::Sum}; + +use itertools::izip; +use num_traits::{FromPrimitive, PrimInt, Zero}; +use rand_distr::uniform::SampleUniform; + +use crate::{ + backend::{GetModulus, Modulus}, + decomposer::{Decomposer, NumInfo, RlweDecomposer}, + lwe::{decrypt_lwe, lwe_key_switch}, + parameters::{BoolParameters, CiphertextModulus}, + random::{DefaultSecureRng, RandomFillUniformInModulus}, + rgsw::{ + decrypt_rlwe, rlwe_auto, rlwe_auto_scratch_rows, RlweCiphertextMutRef, RlweKskRef, + RuntimeScratchMutRef, + }, + utils::{encode_x_pow_si_with_emebedding_factor, tests::Stats, TryConvertFrom1}, + ArithmeticOps, ClientKey, MatrixEntity, MatrixMut, ModInit, Ntt, NttInit, RowEntity, RowMut, + VectorOps, +}; + +use super::keys::tests::{ideal_sk_lwe, ideal_sk_rlwe}; + +pub(crate) trait CollectRuntimeServerKeyStats { + type M; + /// RGSW ciphertext X^{s[s_index]} in evaluation domain where `s` the LWE + /// secret + fn rgsw_cts_lwe_si(&self, s_index: usize) -> &Self::M; + /// Auto key in evaluation domain for automorphism g^k. For auto key for + /// automorphism corresponding to -g, set k = 0 + fn galois_key_for_auto(&self, k: usize) -> &Self::M; + /// LWE key switching key + fn lwe_ksk(&self) -> &Self::M; +} + +#[derive(Default)] +struct ServerKeyStats { + /// Distribution of noise in RGSW ciphertexts + /// + /// We collect statistics for RLWE'(-sm) separately from RLWE'(m) because + /// non-interactive protocol differents between the two. Although we expect + /// the distribution of noise in both to be the same. + brk_rgsw_cts: (Stats, Stats), + /// Distribtion of noise added to RLWE ciphertext after automorphism using + /// Server auto keys. + post_1_auto: Stats, + /// Distribution of noise added in LWE key switching from LWE_{q, s} to + /// LWE_{q, z} where `z` is ideal LWE secret and `s` is ideal RLWE secret + /// using Server's LWE key switching key. + post_lwe_key_switch: Stats, +} + +impl ServerKeyStats +where + T: for<'a> Sum<&'a T>, +{ + fn new() -> Self { + ServerKeyStats { + brk_rgsw_cts: (Stats::default(), Stats::default()), + post_1_auto: Stats::default(), + post_lwe_key_switch: Stats::default(), + } + } + + fn add_noise_brk_rgsw_cts_nsm(&mut self, noise: &[T]) { + self.brk_rgsw_cts.0.add_many_samples(noise); + } + + fn add_noise_brk_rgsw_cts_m(&mut self, noise: &[T]) { + self.brk_rgsw_cts.1.add_many_samples(noise); + } + + fn add_noise_post_1_auto(&mut self, noise: &[T]) { + self.post_1_auto.add_many_samples(&noise); + } + + fn add_noise_post_kwe_key_switch(&mut self, noise: &[T]) { + self.post_lwe_key_switch.add_many_samples(&noise); + } + + fn merge_in(&mut self, other: &Self) { + self.brk_rgsw_cts.0.merge_in(&other.brk_rgsw_cts.0); + self.brk_rgsw_cts.1.merge_in(&other.brk_rgsw_cts.1); + + self.post_1_auto.merge_in(&other.post_1_auto); + self.post_lwe_key_switch + .merge_in(&other.post_lwe_key_switch); + } +} + +fn collect_server_key_stats< + M: MatrixEntity + MatrixMut, + D: Decomposer, + NttOp: NttInit> + Ntt, + ModOp: VectorOps + + ArithmeticOps + + ModInit> + + GetModulus, Element = M::MatElement>, + S: CollectRuntimeServerKeyStats, +>( + parameters: BoolParameters, + client_keys: &[ClientKey], + server_key: &S, +) -> ServerKeyStats +where + M::R: RowMut + RowEntity + TryConvertFrom1<[i32], CiphertextModulus> + Clone, + M::MatElement: Copy + PrimInt + FromPrimitive + SampleUniform + Zero + Debug + NumInfo, +{ + let ideal_sk_rlwe = ideal_sk_rlwe(client_keys); + let ideal_sk_lwe = ideal_sk_lwe(client_keys); + + let embedding_factor = (2 * parameters.rlwe_n().0) / parameters.br_q(); + let rlwe_n = parameters.rlwe_n().0; + let rlwe_q = parameters.rlwe_q(); + let lwe_q = parameters.lwe_q(); + let rlwe_modop = ModOp::new(rlwe_q.clone()); + let rlwe_nttop = NttOp::new(rlwe_q, rlwe_n); + let lwe_modop = ModOp::new(*parameters.lwe_q()); + + let rlwe_x_rgsw_decomposer = parameters.rlwe_rgsw_decomposer::(); + let (rlwe_x_rgsw_gadget_a, rlwe_x_rgsw_gadget_b) = ( + rlwe_x_rgsw_decomposer.a().gadget_vector(), + rlwe_x_rgsw_decomposer.b().gadget_vector(), + ); + + let lwe_ks_decomposer = parameters.lwe_decomposer::(); + + let mut server_key_stats = ServerKeyStats::new(); + + let mut rng = DefaultSecureRng::new(); + + // RGSW ciphertext noise + // Check noise in RGSW ciphertexts of ideal LWE secret elements + { + ideal_sk_lwe.iter().enumerate().for_each(|(s_index, s_i)| { + let rgsw_ct_i = server_key.rgsw_cts_lwe_si(s_index); + + // X^{s[i]} + let m_si = encode_x_pow_si_with_emebedding_factor::( + *s_i, + embedding_factor, + rlwe_n, + rlwe_q, + ); + + // RLWE'(-sm) + let mut neg_s_eval = M::R::try_convert_from(ideal_sk_rlwe.as_slice(), rlwe_q); + rlwe_modop.elwise_neg_mut(neg_s_eval.as_mut()); + rlwe_nttop.forward(neg_s_eval.as_mut()); + + for j in 0..rlwe_x_rgsw_decomposer.a().decomposition_count().0 { + // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) + + // -s[X]*X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_nttop.forward(m_ideal.as_mut()); + rlwe_modop.elwise_mul_mut(m_ideal.as_mut(), neg_s_eval.as_ref()); + rlwe_nttop.backward(m_ideal.as_mut()); + rlwe_modop.elwise_scalar_mul_mut(m_ideal.as_mut(), &rlwe_x_rgsw_gadget_a[j]); + + // RLWE(-s*X^{s_lwe[i]}*B_j) + let mut rlwe_ct = M::zeros(2, rlwe_n); + rlwe_ct + .get_row_mut(0) + .copy_from_slice(rgsw_ct_i.get_row_slice(j)); + rlwe_ct.get_row_mut(1).copy_from_slice( + rgsw_ct_i.get_row_slice(j + rlwe_x_rgsw_decomposer.a().decomposition_count().0), + ); + // RGSW ciphertexts are in eval domain. We put RLWE ciphertexts back in + // coefficient domain + rlwe_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.backward(r.as_mut())); + + let mut m_back = M::R::zeros(rlwe_n); + decrypt_rlwe( + &rlwe_ct, + &ideal_sk_rlwe, + &mut m_back, + &rlwe_nttop, + &rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(m_back.as_mut(), m_ideal.as_ref()); + server_key_stats.add_noise_brk_rgsw_cts_nsm(&Vec::::try_convert_from( + m_back.as_ref(), + rlwe_q, + )); + } + + // RLWE'(m) + for j in 0..rlwe_x_rgsw_decomposer.b().decomposition_count().0 { + // RLWE(B^{j} * X^{s_lwe[i]}) + + // X^{s_lwe[i]}*B_j + let mut m_ideal = m_si.clone(); + rlwe_modop.elwise_scalar_mul_mut(m_ideal.as_mut(), &rlwe_x_rgsw_gadget_b[j]); + + // RLWE(X^{s_lwe[i]}*B_j) + let mut rlwe_ct = M::zeros(2, rlwe_n); + rlwe_ct.get_row_mut(0).copy_from_slice( + rgsw_ct_i.get_row_slice( + j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0), + ), + ); + rlwe_ct + .get_row_mut(1) + .copy_from_slice(rgsw_ct_i.get_row_slice( + j + (2 * rlwe_x_rgsw_decomposer.a().decomposition_count().0) + + rlwe_x_rgsw_decomposer.b().decomposition_count().0, + )); + rlwe_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.backward(r.as_mut())); + + let mut m_back = M::R::zeros(rlwe_n); + decrypt_rlwe( + &rlwe_ct, + &ideal_sk_rlwe, + &mut m_back, + &rlwe_nttop, + &rlwe_modop, + ); + + // diff + rlwe_modop.elwise_sub_mut(m_back.as_mut(), m_ideal.as_ref()); + server_key_stats.add_noise_brk_rgsw_cts_m(&Vec::::try_convert_from( + m_back.as_ref(), + rlwe_q, + )); + } + }); + } + + // Noise in ciphertext after 1 auto + // For each auto key g^k. Sample random polynomial m(X) and multiply with + // -s(X^{g^k}) using key corresponding to auto g^k. Then check the noise in + // resutling RLWE(m(X) * -s(X^{g^k})) + { + let neg_s = { + let mut s = M::R::try_convert_from(ideal_sk_rlwe.as_slice(), rlwe_q); + rlwe_modop.elwise_neg_mut(s.as_mut()); + s + }; + let g = parameters.g(); + let br_q = parameters.br_q(); + let g_dlogs = parameters.auto_element_dlogs(); + let auto_decomposer = parameters.auto_decomposer::(); + let mut scratch_matrix = M::zeros(rlwe_auto_scratch_rows(&auto_decomposer), rlwe_n); + let mut scratch_matrix_ref = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); + + g_dlogs.iter().for_each(|k| { + let g_pow_k = if *k == 0 { + -(g as isize) + } else { + (g.pow(*k as u32) % br_q) as isize + }; + + // Send s(X) -> s(X^{g^k}) + let (auto_index_map, auto_sign_map) = crate::rgsw::generate_auto_map(rlwe_n, g_pow_k); + let mut neg_s_g_k = M::R::zeros(rlwe_n); + izip!( + neg_s.as_ref().iter(), + auto_index_map.iter(), + auto_sign_map.iter() + ) + .for_each(|(el, to_index, to_sign)| { + if !to_sign { + neg_s_g_k.as_mut()[*to_index] = rlwe_modop.neg(el); + } else { + neg_s_g_k.as_mut()[*to_index] = *el; + } + }); + + let mut m = M::R::zeros(rlwe_n); + RandomFillUniformInModulus::random_fill(&mut rng, rlwe_q, m.as_mut()); + + // We want -m(X^{g^k})s(X^{g^k}) after key switch + let want_m = { + let mut m_g_k_eval = M::R::zeros(rlwe_n); + // send m(X) -> m(X^{g^k}) + izip!( + m.as_ref().iter(), + auto_index_map.iter(), + auto_sign_map.iter() + ) + .for_each(|(el, to_index, to_sign)| { + if !to_sign { + m_g_k_eval.as_mut()[*to_index] = rlwe_modop.neg(el); + } else { + m_g_k_eval.as_mut()[*to_index] = *el; + } + }); + + rlwe_nttop.forward(m_g_k_eval.as_mut()); + let mut s_g_k = neg_s_g_k.clone(); + rlwe_nttop.forward(s_g_k.as_mut()); + rlwe_modop.elwise_mul_mut(m_g_k_eval.as_mut(), s_g_k.as_ref()); + rlwe_nttop.backward(m_g_k_eval.as_mut()); + m_g_k_eval + }; + + // RLWE auto sends part A, A(X), of RLWE to A(X^{g^k}) and then multiplies it + // with -s(X^{g^k}) using auto key. Deliberately set RLWE = (0, m(X)) + // (ie. m in part A) to get back RLWE(-m(X^{g^k})s(X^{g^k})) + let mut rlwe = M::zeros(2, rlwe_n); + rlwe.get_row_mut(0).copy_from_slice(m.as_ref()); + + rlwe_auto( + &mut RlweCiphertextMutRef::new(rlwe.as_mut()), + &RlweKskRef::new( + server_key.galois_key_for_auto(*k).as_ref(), + auto_decomposer.decomposition_count().0, + ), + &mut scratch_matrix_ref, + &auto_index_map, + &auto_sign_map, + &rlwe_modop, + &rlwe_nttop, + &auto_decomposer, + false, + ); + + // decrypt RLWE(-m(X)s(X^{g^k]})) + let mut back_m = M::R::zeros(rlwe_n); + decrypt_rlwe(&rlwe, &ideal_sk_rlwe, &mut back_m, &rlwe_nttop, &rlwe_modop); + + // check difference + let mut diff = back_m; + rlwe_modop.elwise_sub_mut(diff.as_mut(), want_m.as_ref()); + server_key_stats + .add_noise_post_1_auto(&Vec::::try_convert_from(diff.as_ref(), rlwe_q)); + }); + + // sample random m + + // key switch + } + + // LWE Key switch + // LWE key switches LWE_in = LWE_{Q_ks,N, s}(m) = (b, a_0, ... a_N) -> LWE_out = + // LWE_{Q_{ks}, n, z}(m) = (b', a'_0, ..., a'n) + // If LWE_in = (0, a = {a_0, ..., a_N}), then LWE_out = LWE(-a \cdot s_{rlwe}) + for _ in 0..100 { + let mut lwe_in = M::R::zeros(rlwe_n + 1); + RandomFillUniformInModulus::random_fill(&mut rng, lwe_q, &mut lwe_in.as_mut()[1..]); + + // Key switch + let mut lwe_out = M::R::zeros(parameters.lwe_n().0 + 1); + lwe_key_switch( + &mut lwe_out, + &lwe_in, + server_key.lwe_ksk(), + &lwe_modop, + &lwe_ks_decomposer, + ); + + // -a \cdot s + let mut want_m = M::MatElement::zero(); + izip!(lwe_in.as_ref().iter().skip(1), ideal_sk_rlwe.iter()).for_each(|(a, b)| { + want_m = lwe_modop.add( + &want_m, + &lwe_modop.mul(a, &lwe_q.map_element_from_i64(*b as i64)), + ); + }); + want_m = lwe_modop.neg(&want_m); + + // decrypt lwe out + let back_m = decrypt_lwe(&lwe_out, &ideal_sk_lwe, &lwe_modop); + + let noise = lwe_modop.sub(&want_m, &back_m); + server_key_stats.add_noise_post_kwe_key_switch(&vec![lwe_q.map_element_to_i64(&noise)]); + } + + server_key_stats + // Auto keys noise + + // Ksk noise +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + #[test] + #[cfg(feature = "interactive_mp")] + fn interactive_key_noise() { + use crate::{ + aggregate_public_key_shares, aggregate_server_key_shares, + bool::{ + evaluator::InteractiveMultiPartyCrs, + keys::{key_size::KeySize, ServerKeyEvaluationDomain}, + }, + collective_pk_share, collective_server_key_share, gen_client_key, + parameters::CiphertextModulus, + random::DefaultSecureRng, + set_common_reference_seed, set_parameter_set, + utils::WithLocal, + BoolEvaluator, DefaultDecomposer, ModularOpsU64, NttBackendU64, + }; + + use super::*; + + set_parameter_set(crate::ParameterSelector::InteractiveLTE8Party); + set_common_reference_seed(InteractiveMultiPartyCrs::random().seed); + let parties = 8; + + let mut server_key_stats = ServerKeyStats::default(); + let mut server_key_share_size = 0usize; + + for i in 0..2 { + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec(); + + let pk = aggregate_public_key_shares(&pk_shares); + let server_key_shares = cks + .iter() + .enumerate() + .map(|(index, k)| collective_server_key_share(k, index, parties, &pk)) + .collect_vec(); + + // In 0th iteration measure server key size + if i == 0 { + // Server key share size of user with last id may not equal server key share + // sizes of other users if LWE dimension does not divides number of parties. + server_key_share_size = std::cmp::max( + server_key_shares.first().unwrap().size(), + server_key_shares.last().unwrap().size(), + ); + } + + // println!("Size: {}", server_key_shares[0].size()); + let seeded_server_key = aggregate_server_key_shares(&server_key_shares); + let server_key_eval = + ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + server_key_stats.merge_in(&collect_server_key_stats::< + _, + DefaultDecomposer, + NttBackendU64, + ModularOpsU64>, + _, + >(parameters, &cks, &server_key_eval)); + } + + println!( + "Common reference seeded server key share key size: {} Bits", + server_key_share_size + ); + + println!( + "Rgsw nsm std log2 {}", + server_key_stats.brk_rgsw_cts.0.std_dev().log2() + ); + println!( + "Rgsw m std log2 {}", + server_key_stats.brk_rgsw_cts.1.std_dev().log2() + ); + println!( + "rlwe post 1 auto std log2 {}", + server_key_stats.post_1_auto.std_dev().log2() + ); + println!( + "key switching noise rlwe secret s to lwe secret z std log2 {}", + server_key_stats.post_lwe_key_switch.std_dev().log2() + ); + } + + const K: usize = 10; + + #[test] + #[cfg(feature = "interactive_mp")] + fn interactive_mp_bool_gates() { + use rand::{thread_rng, RngCore}; + + use crate::{ + aggregate_public_key_shares, aggregate_server_key_shares, + backend::Modulus, + bool::{ + keys::{ + tests::{ideal_sk_rlwe, measure_noise_lwe}, + ServerKeyEvaluationDomain, + }, + print_noise::collect_server_key_stats, + }, + collective_pk_share, collective_server_key_share, gen_client_key, + parameters::CiphertextModulus, + random::DefaultSecureRng, + set_common_reference_seed, set_parameter_set, + utils::{tests::Stats, Global, WithLocal}, + BoolEvaluator, BooleanGates, DefaultDecomposer, Encoder, Encryptor, ModInit, + ModularOpsU64, MultiPartyDecryptor, NttBackendU64, ParameterSelector, RuntimeServerKey, + }; + + set_parameter_set(ParameterSelector::InteractiveLTE8Party); + + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 8; + + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // round 1 + let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec(); + + let pk = aggregate_public_key_shares(&pk_shares); + + // round 2 + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_id, k)| collective_server_key_share(k, user_id, no_of_parties, &pk)) + .collect_vec(); + + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = pk.encrypt(&m0); + let mut ct1 = pk.encrypt(&m1); + + let ideal_sk_rlwe = ideal_sk_rlwe(&cks); + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + let rlwe_modop = ModularOpsU64::new(*parameters.rlwe_q()); + + let mut stats = Stats::new(); + + for _ in 0..K { + // let now = std::time::Instant::now(); + let ct_out = + BoolEvaluator::with_local_mut(|e| e.xor(&ct0, &ct1, RuntimeServerKey::global())); + // println!("Time: {:?}", now.elapsed()); + + let m_expected = m0 ^ m1; + + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out)) + .collect_vec(); + let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); + + assert!(m_out == m_expected, "Expected {m_expected}, got {m_out}"); + + { + let noise = measure_noise_lwe( + &ct_out, + parameters.rlwe_q().encode(m_expected), + &ideal_sk_rlwe, + &rlwe_modop, + ); + stats.add_sample(parameters.rlwe_q().map_element_to_i64(&noise)); + } + + m1 = m0; + m0 = m_expected; + + ct1 = ct0; + ct0 = ct_out; + } + + let server_key_stats = collect_server_key_stats::< + _, + DefaultDecomposer, + NttBackendU64, + ModularOpsU64>, + _, + >( + parameters, + &cks, + &ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(&server_key), + ); + + println!("## Bootstrapping Statistics ##"); + println!("Bootstrapped ciphertext noise std_dev: {}", stats.std_dev()); + + println!("## Key Statistics ##"); + println!( + "Rgsw nsm std_dev {}", + server_key_stats.brk_rgsw_cts.0.std_dev() + ); + println!( + "Rgsw m std_dev {}", + server_key_stats.brk_rgsw_cts.1.std_dev() + ); + println!( + "rlwe post 1 auto std_dev {}", + server_key_stats.post_1_auto.std_dev() + ); + println!( + "key switching noise rlwe secret s to lwe secret z std_dev {}", + server_key_stats.post_lwe_key_switch.std_dev() + ); + println!(); + } + + #[test] + #[cfg(feature = "non_interactive_mp")] + fn non_interactive_mp_bool_gates() { + use rand::{thread_rng, RngCore}; + + use crate::{ + aggregate_server_key_shares, + backend::Modulus, + bool::{ + keys::{ + tests::{ideal_sk_rlwe, measure_noise_lwe}, + NonInteractiveServerKeyEvaluationDomain, + }, + print_noise::collect_server_key_stats, + NonInteractiveBatchedFheBools, + }, + gen_client_key, gen_server_key_share, + parameters::CiphertextModulus, + random::DefaultSecureRng, + set_common_reference_seed, set_parameter_set, + utils::{tests::Stats, Global, WithLocal}, + BoolEvaluator, BooleanGates, DefaultDecomposer, Encoder, Encryptor, KeySwitchWithId, + ModInit, ModularOpsU64, MultiPartyDecryptor, NttBackendU64, ParameterSelector, + RuntimeServerKey, + }; + + set_parameter_set(ParameterSelector::NonInteractiveLTE8Party); + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let parties = 8; + + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck)) + .collect_vec(); + + let seeded_server_key = aggregate_server_key_shares(&server_key_shares); + seeded_server_key.set_server_key(); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + let rlwe_modop = ModularOpsU64::new(*parameters.rlwe_q()); + + let ideal_sk_rlwe = ideal_sk_rlwe(&cks); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = { + let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(vec![m0].as_slice()); + let ct = ct.key_switch(0); + ct.extract(0) + }; + let mut ct1 = { + let ct: NonInteractiveBatchedFheBools<_> = cks[1].encrypt(vec![m1].as_slice()); + let ct = ct.key_switch(1); + ct.extract(0) + }; + + let mut stats = Stats::new(); + + for _ in 0..K { + // let now = std::time::Instant::now(); + let ct_out = + BoolEvaluator::with_local_mut(|e| e.xor(&ct0, &ct1, RuntimeServerKey::global())); + // println!("Time: {:?}", now.elapsed()); + + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_out)) + .collect_vec(); + let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); + + let m_expected = m0 ^ m1; + + { + let noise = measure_noise_lwe( + &ct_out, + parameters.rlwe_q().encode(m_expected), + &ideal_sk_rlwe, + &rlwe_modop, + ); + stats.add_sample(parameters.rlwe_q().map_element_to_i64(&noise)); + } + + assert!(m_out == m_expected, "Expected {m_expected} but got {m_out}"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_out; + } + + // server key statistics + let server_key_stats = collect_server_key_stats::< + _, + DefaultDecomposer, + NttBackendU64, + ModularOpsU64>, + _, + >( + parameters, + &cks, + &NonInteractiveServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ), + ); + + println!("## Bootstrapping Statistics ##"); + println!("Bootstrapped ciphertext noise std_dev: {}", stats.std_dev()); + + println!("## Key Statistics ##"); + println!( + "Rgsw nsm std_dev {}", + server_key_stats.brk_rgsw_cts.0.std_dev() + ); + println!( + "Rgsw m std_dev {}", + server_key_stats.brk_rgsw_cts.1.std_dev() + ); + println!( + "rlwe post 1 auto std_dev {}", + server_key_stats.post_1_auto.std_dev() + ); + println!( + "key switching noise rlwe secret s to lwe secret z std_dev {}", + server_key_stats.post_lwe_key_switch.std_dev() + ); + println!(); + } + + #[test] + #[cfg(feature = "non_interactive_mp")] + fn non_interactive_key_noise() { + use crate::{ + aggregate_server_key_shares, + bool::{ + evaluator::NonInteractiveMultiPartyCrs, + keys::{key_size::KeySize, NonInteractiveServerKeyEvaluationDomain}, + }, + decomposer::DefaultDecomposer, + gen_client_key, gen_server_key_share, + parameters::CiphertextModulus, + random::DefaultSecureRng, + set_common_reference_seed, set_parameter_set, + utils::WithLocal, + BoolEvaluator, ModularOpsU64, NttBackendU64, + }; + + use super::*; + + set_parameter_set(crate::ParameterSelector::NonInteractiveLTE8Party); + set_common_reference_seed(NonInteractiveMultiPartyCrs::random().seed); + let parties = 8; + + let mut server_key_stats = ServerKeyStats::default(); + let mut server_key_share_size = 0; + for i in 0..2 { + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_id, k)| gen_server_key_share(user_id, parties, k)) + .collect_vec(); + + // Collect server key size in the 0th iteration + if i == 0 { + // Server key share size may differ for user with last id from + // the share size of other users if the LWE dimension `n` is not + // divisible by no. of parties. + server_key_share_size = std::cmp::max( + server_key_shares.first().unwrap().size(), + server_key_shares.last().unwrap().size(), + ); + } + + let server_key = aggregate_server_key_shares(&server_key_shares); + + let server_key_eval = NonInteractiveServerKeyEvaluationDomain::< + _, + _, + DefaultSecureRng, + NttBackendU64, + >::from(&server_key); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + server_key_stats.merge_in(&collect_server_key_stats::< + _, + DefaultDecomposer, + NttBackendU64, + ModularOpsU64>, + _, + >(parameters, &cks, &server_key_eval)); + } + + println!( + "Common reference seeded server key share key size: {} Bits", + server_key_share_size + ); + println!( + "Rgsw nsm std log2 {}", + server_key_stats.brk_rgsw_cts.0.std_dev().abs().log2() + ); + println!( + "Rgsw m std log2 {}", + server_key_stats.brk_rgsw_cts.1.std_dev().abs().log2() + ); + println!( + "rlwe post 1 auto std log2 {}", + server_key_stats.post_1_auto.std_dev().abs().log2() + ); + println!( + "key switching noise rlwe secret s to lwe secret z std log2 {}", + server_key_stats.post_lwe_key_switch.std_dev().abs().log2() + ); + } + + #[test] + #[cfg(feature = "non_interactive_mp")] + fn enc_under_sk_and_key_switch() { + use rand::{thread_rng, Rng}; + + use crate::{ + aggregate_server_key_shares, + bool::{keys::tests::ideal_sk_rlwe, ni_mp_api::NonInteractiveBatchedFheBools}, + gen_client_key, gen_server_key_share, + rgsw::decrypt_rlwe, + set_common_reference_seed, set_parameter_set, + utils::{tests::Stats, TryConvertFrom1, WithLocal}, + BoolEvaluator, Encoder, Encryptor, KeySwitchWithId, ModInit, ModularOpsU64, + NttBackendU64, NttInit, ParameterSelector, VectorOps, + }; + + set_parameter_set(ParameterSelector::NonInteractiveLTE2Party); + set_common_reference_seed([2; 32]); + + let parties = 2; + + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + + let key_shares = cks + .iter() + .enumerate() + .map(|(user_index, ck)| gen_server_key_share(user_index, parties, ck)) + .collect_vec(); + + let seeded_server_key = aggregate_server_key_shares(&key_shares); + seeded_server_key.set_server_key(); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0); + let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q()); + + let m = (0..parameters.rlwe_n().0) + .map(|_| thread_rng().gen_bool(0.5)) + .collect_vec(); + let ct: NonInteractiveBatchedFheBools<_> = cks[0].encrypt(m.as_slice()); + let ct = ct.key_switch(0); + + let ideal_rlwe_sk = ideal_sk_rlwe(&cks); + + let message = m + .iter() + .map(|b| parameters.rlwe_q().encode(*b)) + .collect_vec(); + + let mut m_out = vec![0u64; parameters.rlwe_n().0]; + decrypt_rlwe( + &ct.data[0], + &ideal_rlwe_sk, + &mut m_out, + &nttop, + &rlwe_q_modop, + ); + + let mut diff = m_out; + rlwe_q_modop.elwise_sub_mut(diff.as_mut_slice(), message.as_ref()); + + let mut stats = Stats::new(); + stats.add_many_samples(&Vec::::try_convert_from( + diff.as_slice(), + parameters.rlwe_q(), + )); + println!("Noise std log2: {}", stats.std_dev().abs().log2()); + } + + #[test] + fn mod_switch_noise() { + // Experiment to check mod switch noise using different secret dist in + // multi-party setting + + use itertools::izip; + use num_traits::ToPrimitive; + + use crate::{ + backend::{Modulus, ModulusPowerOf2}, + parameters::SecretKeyDistribution, + random::{DefaultSecureRng, RandomFillGaussian, RandomFillUniformInModulus}, + utils::{fill_random_ternary_secret_with_hamming_weight, tests::Stats}, + ArithmeticOps, ModInit, + }; + + fn mod_switch(v: u64, q_from: u64, q_to: u64) -> f64 { + (v as f64) * (q_to as f64) / q_from as f64 + } + + fn mod_switch_round(v: u64, q_from: u64, q_to: u64) -> u64 { + mod_switch(v, q_from, q_to).round().to_u64().unwrap() + } + + fn mod_switch_odd(v: u64, q_from: u64, q_to: u64) -> u64 { + let odd_v = mod_switch(v, q_from, q_to).floor().to_u64().unwrap(); + odd_v + ((odd_v & 1) ^ 1) + } + + fn sample_secret(n: usize, dist: &SecretKeyDistribution) -> Vec { + let mut s = vec![0i32; n]; + let mut rng = DefaultSecureRng::new(); + + match dist { + SecretKeyDistribution::ErrorDistribution => { + RandomFillGaussian::random_fill(&mut rng, s.as_mut_slice()); + } + SecretKeyDistribution::TernaryDistribution => { + fill_random_ternary_secret_with_hamming_weight(&mut s, n >> 1, &mut rng); + } + } + + s + } + + let parties = 2; + let q_from = 1 << 40; + let q_to = 1 << 20; + let n = 480; + let lweq_in_modop = ModulusPowerOf2::new(q_from); + let lweq_out_modop = ModulusPowerOf2::new(q_to); + let secret_dist = SecretKeyDistribution::ErrorDistribution; + + let mut stats_ms_noise = Stats::new(); + let mut stats_ms_rounding_err = Stats::new(); + + for _ in 0..1000000 { + let mut rng = DefaultSecureRng::new(); + + // sample secrets + + let s = { + let mut s = vec![0i32; n]; + for _ in 0..parties { + let temp = sample_secret(n, &secret_dist); + izip!(s.iter_mut(), temp.iter()).for_each(|(si, ti)| { + *si = *si + *ti; + }); + } + s + }; + + let m = 10; + + // LWE encryption without noise + let mut lwe_in = vec![0u64; n + 1]; + { + RandomFillUniformInModulus::random_fill(&mut rng, &q_from, &mut lwe_in[1..]); + let mut b = m; + izip!(lwe_in.iter().skip(1), s.iter()).for_each(|(ai, si)| { + b = lweq_in_modop.add( + &b, + &lweq_in_modop.mul(ai, &q_from.map_element_from_i64(*si as i64)), + ); + }); + lwe_in[0] = b; + } + + // Mod switch + let lwe_out = lwe_in + .iter() + .map(|v| { + // mod_switch_round(*v, q_from, q_to) + mod_switch_odd(*v, q_from, q_to) + }) + .collect_vec(); + + let rounding_errors = izip!(lwe_out.iter(), lwe_in.iter()) + .map(|(v_out, v_in)| { + let r_i = mod_switch(*v_in, q_from, q_to) - (*v_out as f64); + r_i + }) + .collect_vec(); + stats_ms_rounding_err.add_many_samples(&rounding_errors); + + // LWE decrypt and calculate ms noise + let mut m_back = 0; + izip!(lwe_out.iter().skip(1), s.iter()).for_each(|(ai, si)| { + m_back = lweq_out_modop.add( + &m_back, + &lweq_out_modop.mul(ai, &q_from.map_element_from_i64(*si as i64)), + ); + }); + m_back = lweq_out_modop.sub(&lwe_out[0], &m_back); + let noise = lweq_out_modop.sub(&m_back, &m); + stats_ms_noise.add_many_samples(&vec![q_to.map_element_to_i64(&noise)]); + } + + println!("ms noise variance: {}", stats_ms_noise.variance()); + println!("ms rounding errors mean: {}", stats_ms_rounding_err.mean()); + println!( + "ms rounding errors variance: {}", + stats_ms_rounding_err.variance() + ); + } +} diff --git a/src/decomposer.rs b/src/decomposer.rs index db9238e..d50ae7f 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -1,32 +1,114 @@ -use itertools::Itertools; -use num_traits::{AsPrimitive, One, PrimInt, ToPrimitive, WrappingSub, Zero}; -use std::{fmt::Debug, marker::PhantomData, ops::Rem}; +use itertools::{izip, Itertools}; +use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub}; +use std::fmt::{Debug, Display}; -use crate::backend::{ArithmeticOps, ModularOpsU64}; +use crate::{ + backend::ArithmeticOps, + parameters::{ + DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams, + }, + utils::log2, +}; -pub fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { - let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); - let ignored_limbs = d_ideal - d; - (ignored_limbs..ignored_limbs + d) +fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { + assert!(logq >= (logb * d)); + let ignored_bits = logq - (logb * d); + + (0..d) .into_iter() - .map(|i| T::one() << (logb * i)) + .map(|i| T::one() << (logb * i + ignored_bits)) .collect_vec() } +pub trait RlweDecomposer { + type Element; + type D: Decomposer; + + /// Decomposer for RLWE Part A + fn a(&self) -> &Self::D; + /// Decomposer for RLWE Part B + fn b(&self) -> &Self::D; +} + +impl RlweDecomposer for (D, D) +where + D: Decomposer, +{ + type D = D; + type Element = D::Element; + fn a(&self) -> &Self::D { + &self.0 + } + fn b(&self) -> &Self::D { + &self.1 + } +} + +impl DoubleDecomposerParams for D +where + D: RlweDecomposer, +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + + fn decomposition_base(&self) -> Self::Base { + assert!( + Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b()) + ); + Decomposer::decomposition_base(self.a()) + } + fn decomposition_count_a(&self) -> Self::Count { + Decomposer::decomposition_count(self.a()) + } + fn decomposition_count_b(&self) -> Self::Count { + Decomposer::decomposition_count(self.b()) + } +} + +impl SingleDecomposerParams for D +where + D: Decomposer, +{ + type Base = DecompostionLogBase; + type Count = DecompositionCount; + fn decomposition_base(&self) -> Self::Base { + Decomposer::decomposition_base(self) + } + fn decomposition_count(&self) -> Self::Count { + Decomposer::decomposition_count(self) + } +} + pub trait Decomposer { type Element; - //FIXME(Jay): there's no reason why it returns a vec instead of an iterator - fn decompose(&self, v: &Self::Element) -> Vec; - fn d(&self) -> usize; + type Iter: Iterator; + fn new(q: Self::Element, logb: usize, d: usize) -> Self; + + fn decompose_to_vec(&self, v: &Self::Element) -> Vec; + fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; + fn decomposition_count(&self) -> DecompositionCount; + fn decomposition_base(&self) -> DecompostionLogBase; + fn gadget_vector(&self) -> Vec; } pub struct DefaultDecomposer { + /// Ciphertext modulus q: T, + /// Log of ciphertext modulus logq: usize, + /// Log of base B logb: usize, + /// base B + b: T, + /// (B - 1). To simulate (% B) as &(B-1), that is extract least significant + /// logb bits + b_mask: T, + /// B/2 + bby2: T, + /// Decomposition count d: usize, + /// No. of bits to ignore in rounding ignore_bits: usize, - ignore_limbs: usize, } pub trait NumInfo { @@ -44,121 +126,246 @@ impl NumInfo for u128 { } impl DefaultDecomposer { - pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { - // if q is power of 2, then BITS - leading zeros outputs logq + 1. - let logq = if q & (q - T::one()) == T::zero() { - (T::BITS - q.leading_zeros() - 1) as usize - } else { - (T::BITS - q.leading_zeros()) as usize - }; + fn recompose(&self, limbs: &[T], modq_op: &Op) -> T + where + Op: ArithmeticOps, + { + let mut value = T::zero(); + let gadget_vector = gadget_vector(self.logq, self.logb, self.d); + assert!(limbs.len() == gadget_vector.len()); + izip!(limbs.iter(), gadget_vector.iter()) + .for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta))); + + value + } +} + +impl< + T: PrimInt + + ToPrimitive + + FromPrimitive + + WrappingSub + + WrappingAdd + + NumInfo + + From + + Display + + Debug, + > Decomposer for DefaultDecomposer +{ + type Element = T; + type Iter = DecomposerIter; + + fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { + // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. + let logq = log2(&q); + assert!( + logq >= (logb * d), + "Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}" + ); - let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); - let ignore_limbs = (d_ideal - d); - let ignore_bits = (d_ideal - d) * logb; + let ignore_bits = logq - (logb * d); DefaultDecomposer { q, logq, logb, + b: T::one() << logb, + b_mask: (T::one() << logb) - T::one(), + bby2: T::one() << (logb - 1), d, ignore_bits, - ignore_limbs, } } - fn recompose(&self, limbs: &[T], modq_op: &Op) -> T - where - Op: ArithmeticOps, - { - let mut value = T::zero(); - for i in self.ignore_limbs..self.ignore_limbs + self.d { - value = modq_op.add( - &value, - &(modq_op.mul(&limbs[i], &(T::one() << (self.logb * i)))), - ) + fn decompose_to_vec(&self, value: &T) -> Vec { + let q = self.q; + let logb = self.logb; + let b = T::one() << logb; + let full_mask = b - T::one(); + let bby2 = b >> 1; + + let mut value = *value; + if value >= (q >> 1) { + value = !(q - value) + T::one() } - value - } -} + value = round_value(value, self.ignore_bits); + let mut out = Vec::with_capacity(self.d); + for _ in 0..(self.d) { + let k_i = value & full_mask; -impl Decomposer for DefaultDecomposer { - type Element = T; - fn decompose(&self, value: &T) -> Vec { - let value = round_value(*value, self.ignore_bits); + value = (value - k_i) >> logb; - let q = self.q; - let logb = self.logb; - // let b = T::one() << logb; // base - let b_by2 = T::one() << (logb - 1); - // let neg_b_by2_modq = q - b_by2; - let full_mask = (T::one() << logb) - T::one(); - // let half_mask = b_by2 - T::one(); - let mut carry = T::zero(); - let mut out = Vec::::with_capacity(self.d); - for i in 0..self.d { - let mut limb = ((value >> (logb * i)) & full_mask) + carry; - - carry = limb & b_by2; - limb = (q + limb) - (carry << 1); - if limb > q { - limb = limb - q; + if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) { + out.push(q - (b - k_i)); + value = value + T::one(); + } else { + out.push(k_i); } - out.push(limb); - - carry = carry >> (logb - 1); } return out; } - fn d(&self) -> usize { - self.d + fn decomposition_count(&self) -> DecompositionCount { + DecompositionCount(self.d) + } + + fn decomposition_base(&self) -> DecompostionLogBase { + DecompostionLogBase(self.logb) + } + + fn decompose_iter(&self, value: &T) -> DecomposerIter { + let mut value = *value; + if value >= (self.q >> 1) { + value = !(self.q - value) + T::one() + } + value = round_value(value, self.ignore_bits); + + DecomposerIter { + value, + q: self.q, + logq: self.logq, + logb: self.logb, + b: self.b, + bby2: self.bby2, + b_mask: self.b_mask, + steps_left: self.d, + } + } + + fn gadget_vector(&self) -> Vec { + return gadget_vector(self.logq, self.logb, self.d); } } -fn round_value(value: T, ignore_bits: usize) -> T { +impl DefaultDecomposer {} + +pub struct DecomposerIter { + /// Value to decompose + value: T, + steps_left: usize, + /// (1 << logb) - 1 (for % (1< + WrappingSub + Display> Iterator for DecomposerIter { + type Item = T; + + fn next(&mut self) -> Option { + if self.steps_left != 0 { + self.steps_left -= 1; + let k_i = self.value & self.b_mask; + + self.value = (self.value - k_i) >> self.logb; + + // if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & + // T::one()) == T::one())) { self.value = self.value + // + T::one(); Some(self.q + k_i - self.b) + // } else { + // Some(k_i) + // } + + // Following is without branching impl of the commented version above. It + // happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other + // parameters as well but I haven't tested) by roughly 15ms. + // Suprisingly the improvement does not show up when I benchmark + // `decomposer_iter` in isolation. Putting this remark here as a + // future task to investiage (TODO). + let carry_bool = + k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())); + let carry = >::from(carry_bool); + let neg_carry = (T::zero().wrapping_sub(&carry)); + self.value = self.value + carry; + Some((neg_carry & self.q) + k_i - (carry << self.logb)) + + // Some( + // (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i + // - (carry << self.logb), ) + + // Some(k_i) + } else { + None + } + } +} + +fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; } let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); - (value >> ignore_bits) + ignored_msb + (value >> ignore_bits).wrapping_add(&ignored_msb) } #[cfg(test)] mod tests { + + use itertools::Itertools; use rand::{thread_rng, Rng}; - use crate::{backend::ModularOpsU64, decomposer::round_value, utils::generate_prime}; + use crate::{ + backend::{ModInit, ModularOpsU64}, + decomposer::round_value, + utils::generate_prime, + }; use super::{Decomposer, DefaultDecomposer}; #[test] fn decomposition_works() { - let logq = 50; - let logb = 5; - let d = 10; - - // q is prime of bits logq and i is true, other q = 1< { type MatElement; type R: Row; @@ -34,6 +45,13 @@ pub trait Matrix: AsRef<[Self::R]> { fn get(&self, row_idx: usize, column_idx: usize) -> &Self::MatElement { &self.as_ref()[row_idx].as_ref()[column_idx] } + + fn split_at_row(&self, idx: usize) -> (&[::R], &[::R]) { + self.as_ref().split_at(idx) + } + + /// Does the matrix fit sub-matrix of dimension row x col + fn fits(&self, row: usize, col: usize) -> bool; } pub trait MatrixMut: Matrix + AsMut<[::R]> @@ -52,7 +70,7 @@ where self.as_mut()[row_idx].as_mut()[column_idx] = val; } - fn split_at_row( + fn split_at_row_mut( &mut self, idx: usize, ) -> (&mut [::R], &mut [::R]) { @@ -72,9 +90,8 @@ pub trait Row: AsRef<[Self::Element]> { pub trait RowMut: Row + AsMut<[::Element]> {} -trait Secret { - type Element; - fn values(&self) -> &[Self::Element]; +pub trait RowEntity: Row { + fn zeros(col: usize) -> Self; } impl Matrix for Vec> { @@ -84,9 +101,40 @@ impl Matrix for Vec> { fn dimension(&self) -> (usize, usize) { (self.len(), self[0].len()) } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } +} + +impl Matrix for &[Vec] { + type MatElement = T; + type R = Vec; + + fn dimension(&self) -> (usize, usize) { + (self.len(), self[0].len()) + } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } +} + +impl Matrix for &mut [Vec] { + type MatElement = T; + type R = Vec; + + fn dimension(&self) -> (usize, usize) { + (self.len(), self[0].len()) + } + + fn fits(&self, row: usize, col: usize) -> bool { + self.len() >= row && self[0].len() >= col + } } impl MatrixMut for Vec> {} +impl MatrixMut for &mut [Vec] {} impl MatrixEntity for Vec> { fn zeros(row: usize, col: usize) -> Self { @@ -98,4 +146,66 @@ impl Row for Vec { type Element = T; } +impl Row for [T] { + type Element = T; +} + impl RowMut for Vec {} + +impl RowEntity for Vec { + fn zeros(col: usize) -> Self { + vec![T::zero(); col] + } +} + +pub trait Encryptor { + fn encrypt(&self, m: &M) -> C; +} + +pub trait Decryptor { + fn decrypt(&self, c: &C) -> M; +} + +pub trait MultiPartyDecryptor { + type DecryptionShare; + + fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare; + fn aggregate_decryption_shares(&self, c: &C, shares: &[Self::DecryptionShare]) -> M; +} + +pub trait KeySwitchWithId { + fn key_switch(&self, user_id: usize) -> C; +} + +pub trait SampleExtractor { + /// Extract ciphertext at `index` + fn extract_at(&self, index: usize) -> R; + /// Extract all ciphertexts + fn extract_all(&self) -> Vec; + /// Extract first `how_many` ciphertexts + fn extract_many(&self, how_many: usize) -> Vec; +} + +trait Encoder { + fn encode(&self, v: F) -> T; +} + +trait SizeInBitsWithLogModulus { + /// Returns size of `Self` containing several elements mod Q where + /// 2^{log_modulus-1} < Q <= `2^log_modulus` + fn size(&self, log_modulus: usize) -> usize; +} + +impl SizeInBitsWithLogModulus for Vec> { + fn size(&self, log_modulus: usize) -> usize { + let mut total = 0; + self.iter().for_each(|r| total += log_modulus * r.len()); + total + } +} + +impl SizeInBitsWithLogModulus for Vec { + fn size(&self, log_modulus: usize) -> usize { + self.len() * log_modulus + } +} diff --git a/src/lwe.rs b/src/lwe.rs index bb3cf35..aa2188c 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -1,220 +1,250 @@ use std::fmt::Debug; -use itertools::{izip, Itertools}; -use num_traits::{abs, Zero}; +use itertools::izip; +use num_traits::Zero; use crate::{ - backend::{ArithmeticOps, VectorOps}, + backend::{ArithmeticOps, GetModulus, VectorOps}, decomposer::Decomposer, - lwe, - num::UnsignedInteger, - random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG}, - utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, - Matrix, MatrixEntity, MatrixMut, Row, RowMut, Secret, + random::{RandomFillUniformInModulus, RandomGaussianElementInModulus}, + utils::TryConvertFrom1, + Matrix, Row, RowEntity, RowMut, }; -trait LweKeySwitchParameters { - fn n_in(&self) -> usize; - fn n_out(&self) -> usize; - fn d_ks(&self) -> usize; -} - -trait LweCiphertext {} - -struct LweSecret { - values: Vec, -} - -impl Secret for LweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } -} - -impl LweSecret { - fn random(hw: usize, n: usize) -> LweSecret { - DefaultSecureRng::with_local_mut(|rng| { - let mut out = vec![0i32; n]; - fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); - - LweSecret { values: out } - }) - } -} - -fn lwe_key_switch< +pub(crate) fn lwe_key_switch< M: Matrix, - Mmut: MatrixMut + MatrixEntity, + Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>, Op: VectorOps + ArithmeticOps, D: Decomposer, >( - lwe_out: &mut Mmut, - lwe_in: &M, + lwe_out: &mut Ro, + lwe_in: &Ro, lwe_ksk: &M, operator: &Op, decomposer: &D, -) where - ::R: RowMut, -{ - assert!(lwe_ksk.dimension().0 == ((lwe_in.dimension().1 - 1) * decomposer.d())); - assert!(lwe_out.dimension() == (1, lwe_ksk.dimension().1)); - - let mut scratch_space = Mmut::zeros(1, lwe_out.dimension().1); +) { + assert!( + lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count().0) + ); + assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1); let lwe_in_a_decomposed = lwe_in - .get_row(0) + .as_ref() + .iter() .skip(1) - .flat_map(|ai| decomposer.decompose(ai)); + .flat_map(|ai| decomposer.decompose_iter(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { - operator.elwise_scalar_mul(scratch_space.get_row_mut(0), beta_ij_lwe.as_ref(), &ai_j); - operator.elwise_add_mut(lwe_out.get_row_mut(0), scratch_space.get_row_slice(0)) + // let now = std::time::Instant::now(); + operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); + // println!("Time elwise_fma_scalar_mut: {:?}", now.elapsed()); }); - let out_b = operator.add(lwe_out.get(0, 0), lwe_in.get(0, 0)); - lwe_out.set(0, 0, out_b); + let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]); + lwe_out.as_mut()[0] = out_b; } -fn lwe_ksk_keygen< - Mmut: MatrixMut, - S: Secret, - Op: VectorOps + ArithmeticOps, - R: RandomGaussianDist - + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, +pub(crate) fn seeded_lwe_ksk_keygen< + Ro: RowMut + RowEntity, + S, + Op: VectorOps + + ArithmeticOps + + GetModulus, + R: RandomGaussianElementInModulus, + PR: RandomFillUniformInModulus<[Ro::Element], Op::M>, >( - lwe_sk_in: &S, - lwe_sk_out: &S, - ksk_out: &mut Mmut, - gadget: &[Mmut::MatElement], + from_lwe_sk: &[S], + to_lwe_sk: &[S], + gadget: &[Ro::Element], operator: &Op, + p_rng: &mut PR, rng: &mut R, -) where - ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, - Mmut::MatElement: Zero + Debug, +) -> Ro +where + Ro: TryConvertFrom1<[S], Op::M>, + Ro::Element: Zero + Debug, { - assert!( - ksk_out.dimension() - == ( - lwe_sk_in.values().len() * gadget.len(), - lwe_sk_out.values().len() + 1, - ) - ); + let mut ksk_out = Ro::zeros(from_lwe_sk.len() * gadget.len()); let d = gadget.len(); - let modulus = VectorOps::modulus(operator); - let mut neg_sk_in_m = Mmut::try_convert_from(lwe_sk_in.values(), &modulus); - operator.elwise_neg_mut(neg_sk_in_m.get_row_mut(0)); - let sk_out_m = Mmut::try_convert_from(lwe_sk_out.values(), &modulus); - - izip!( - neg_sk_in_m.get_row(0), - ksk_out.iter_rows_mut().chunks(d).into_iter() - ) - .for_each(|(neg_sk_in_si, d_ks_lwes)| { - izip!(gadget.iter(), d_ks_lwes.into_iter()).for_each(|(f, lwe)| { - // sample `a` - RandomUniformDist::random_fill(rng, &modulus, &mut lwe.as_mut()[1..]); - - // a * z - let mut az = Mmut::MatElement::zero(); - izip!(lwe.as_ref()[1..].iter(), sk_out_m.get_row(0)).for_each(|(ai, si)| { - let ai_si = operator.mul(ai, si); - az = operator.add(&az, &ai_si); - }); - - // a*z + (-s_i)*\beta^j + e - let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); - let mut e = Mmut::MatElement::zero(); - RandomGaussianDist::random_fill(rng, &modulus, &mut e); - b = operator.add(&b, &e); - - lwe.as_mut()[0] = b; + let modulus = operator.modulus(); + let mut neg_sk_in_m = Ro::try_convert_from(from_lwe_sk, modulus); + operator.elwise_neg_mut(neg_sk_in_m.as_mut()); + let sk_out_m = Ro::try_convert_from(to_lwe_sk, modulus); + + let mut scratch = Ro::zeros(to_lwe_sk.len()); + + izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each( + |(neg_sk_in_si, d_lwes_partb)| { + izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(beta, lwe_b)| { + // sample `a` + RandomFillUniformInModulus::random_fill(p_rng, &modulus, scratch.as_mut()); + + // a * z + let mut az = Ro::Element::zero(); + izip!(scratch.as_ref().iter(), sk_out_m.as_ref()).for_each(|(ai, si)| { + let ai_si = operator.mul(ai, si); + az = operator.add(&az, &ai_si); + }); + + // a*z + (-s_i)*\beta^j + e + let mut b = operator.add(&az, &operator.mul(beta, neg_sk_in_si)); + let e = RandomGaussianElementInModulus::random(rng, &modulus); + b = operator.add(&b, &e); + + *lwe_b = b; + }) + }, + ); - // dbg!(&lwe.as_mut(), &f); - }) - }); + ksk_out } /// Encrypts encoded message m as LWE ciphertext -fn encrypt_lwe< - Mmut: MatrixMut + MatrixEntity, - R: RandomGaussianDist - + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, - S: Secret, - Op: ArithmeticOps, +pub(crate) fn encrypt_lwe< + Ro: RowMut + RowEntity, + Op: ArithmeticOps + GetModulus, + R: RandomGaussianElementInModulus + + RandomFillUniformInModulus<[Ro::Element], Op::M>, + S, >( - lwe_out: &mut Mmut, - m: Mmut::MatElement, - s: &S, + m: &Ro::Element, + s: &[S], operator: &Op, rng: &mut R, -) where - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, - Mmut::MatElement: Zero, - ::R: RowMut, +) -> Ro +where + Ro: TryConvertFrom1<[S], Op::M>, + Ro::Element: Zero, { - let s = Mmut::try_convert_from(s.values(), &operator.modulus()); - assert!(s.dimension().0 == (lwe_out.dimension().0)); - assert!(s.dimension().1 == (lwe_out.dimension().1 - 1)); + let s = Ro::try_convert_from(s, operator.modulus()); + let mut lwe_out = Ro::zeros(s.as_ref().len() + 1); // a*s - RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.get_row_mut(0)[1..]); - let mut sa = Mmut::MatElement::zero(); - izip!(lwe_out.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + RandomFillUniformInModulus::random_fill(rng, operator.modulus(), &mut lwe_out.as_mut()[1..]); + let mut sa = Ro::Element::zero(); + izip!(lwe_out.as_mut().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); // b = a*s + e + m - let mut e = Mmut::MatElement::zero(); - RandomGaussianDist::random_fill(rng, &operator.modulus(), &mut e); - let b = operator.add(&operator.add(&sa, &e), &m); - lwe_out.set(0, 0, b); + let e = RandomGaussianElementInModulus::random(rng, operator.modulus()); + let b = operator.add(&operator.add(&sa, &e), m); + lwe_out.as_mut()[0] = b; + + lwe_out } -fn decrypt_lwe, S: Secret>( - lwe_ct: &M, - s: &S, +pub(crate) fn decrypt_lwe< + Ro: Row, + Op: ArithmeticOps + GetModulus, + S, +>( + lwe_ct: &Ro, + s: &[S], operator: &Op, -) -> M::MatElement +) -> Ro::Element where - M: TryConvertFrom<[S::Element], Parameters = M::MatElement>, - M::MatElement: Zero, + Ro: TryConvertFrom1<[S], Op::M>, + Ro::Element: Zero, { - let s = M::try_convert_from(s.values(), &operator.modulus()); + let s = Ro::try_convert_from(s, operator.modulus()); - let mut sa = M::MatElement::zero(); - izip!(lwe_ct.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + let mut sa = Ro::Element::zero(); + izip!(lwe_ct.as_ref().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { let tmp = operator.mul(ai, si); sa = operator.add(&tmp, &sa); }); - let b = &lwe_ct.get_row_slice(0)[0]; + let b = &lwe_ct.as_ref()[0]; operator.sub(b, &sa) } #[cfg(test)] mod tests { + use std::marker::PhantomData; + + use itertools::izip; + use crate::{ - backend::ModularOpsU64, - decomposer::{gadget_vector, DefaultDecomposer}, - lwe::lwe_key_switch, - random::DefaultSecureRng, + backend::{ModInit, ModulusPowerOf2}, + decomposer::DefaultDecomposer, + random::{DefaultSecureRng, NewWithSeed}, + utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal}, + MatrixEntity, MatrixMut, }; - use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret}; + use super::*; + + const K: usize = 50; + + #[derive(Clone)] + struct LweSecret { + pub(crate) values: Vec, + } + + impl LweSecret { + fn random(hw: usize, n: usize) -> LweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + + LweSecret { values: out } + }) + } + + fn values(&self) -> &[i32] { + &self.values + } + } + + struct LweKeySwitchingKey { + data: M, + _phantom: PhantomData, + } + + impl< + M: MatrixMut + MatrixEntity, + R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, + > From<&(M::R, R::Seed, usize, M::MatElement)> for LweKeySwitchingKey + where + M::R: RowMut, + R::Seed: Clone, + M::MatElement: Copy, + { + fn from(value: &(M::R, R::Seed, usize, M::MatElement)) -> Self { + let data_in = &value.0; + let seed = &value.1; + let to_lwe_n = value.2; + let modulus = value.3; + + let mut p_rng = R::new_with_seed(seed.clone()); + let mut data = M::zeros(data_in.as_ref().len(), to_lwe_n + 1); + izip!(data_in.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| { + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &modulus, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }); + LweKeySwitchingKey { + data, + _phantom: PhantomData, + } + } + } #[test] fn encrypt_decrypt_works() { - let logq = 20; + let logq = 16; let q = 1u64 << logq; let lwe_n = 1024; let logp = 3; - let modq_op = ModularOpsU64::new(q); + let modq_op = ModulusPowerOf2::new(q); let lwe_sk = LweSecret::random(lwe_n >> 1, lwe_n); let mut rng = DefaultSecureRng::new(); @@ -222,9 +252,9 @@ mod tests { // encrypt for m in 0..1u64 << logp { let encoded_m = m << (logq - logp); - let mut lwe_ct = vec![vec![0u64; lwe_n + 1]]; - encrypt_lwe(&mut lwe_ct, encoded_m, &lwe_sk, &modq_op, &mut rng); - let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk, &modq_op); + let lwe_ct = + encrypt_lwe::, _, _, _>(&encoded_m, &lwe_sk.values(), &modq_op, &mut rng); + let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) % (1u64 << logp); @@ -234,52 +264,69 @@ mod tests { #[test] fn key_switch_works() { - let logq = 16; - let logp = 3; + let logq = 20; + let logp = 2; let q = 1u64 << logq; - let lwe_in_n = 1024; - let lwe_out_n = 470; - let d_ks = 3; + let lwe_in_n = 2048; + let lwe_out_n = 600; + let d_ks = 5; let logb = 4; let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); let mut rng = DefaultSecureRng::new(); - let modq_op = ModularOpsU64::new(q); + let modq_op = ModulusPowerOf2::new(q); // genrate ksk - for _ in 0..10 { - let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n]; - let gadget = gadget_vector(logq, logb, d_ks); - lwe_ksk_keygen( - &lwe_sk_in, - &lwe_sk_out, - &mut ksk, + for _ in 0..1 { + let mut ksk_seed = [0u8; 32]; + rng.fill_bytes(&mut ksk_seed); + let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed); + let decomposer = DefaultDecomposer::new(q, logb, d_ks); + let gadget = decomposer.gadget_vector(); + let seeded_ksk = seeded_lwe_ksk_keygen( + &lwe_sk_in.values(), + &lwe_sk_out.values(), &gadget, &modq_op, + &mut p_rng, &mut rng, ); // println!("{:?}", ksk); + let ksk = LweKeySwitchingKey::>, DefaultSecureRng>::from(&( + seeded_ksk, ksk_seed, lwe_out_n, q, + )); for m in 0..(1 << logp) { // encrypt using lwe_sk_in let encoded_m = m << (logq - logp); - let mut lwe_in_ct = vec![vec![0u64; lwe_in_n + 1]]; - encrypt_lwe(&mut lwe_in_ct, encoded_m, &lwe_sk_in, &modq_op, &mut rng); + let lwe_in_ct = encrypt_lwe(&encoded_m, lwe_sk_in.values(), &modq_op, &mut rng); // key switch from lwe_sk_in to lwe_sk_out - let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); - let mut lwe_out_ct = vec![vec![0u64; lwe_out_n + 1]]; - lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer); + let mut lwe_out_ct = vec![0u64; lwe_out_n + 1]; + let now = std::time::Instant::now(); + lwe_key_switch( + &mut lwe_out_ct, + &lwe_in_ct, + &ksk.data, + &modq_op, + &decomposer, + ); + println!("Time: {:?}", now.elapsed()); // decrypt lwe_out_ct using lwe_sk_out - let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out, &modq_op); - let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() - as u64) - % (1u64 << logp); - assert_eq!(m, m_back, "Expected {m} but got {m_back}"); - // dbg!(m, m_back); + // TODO(Jay): Fix me + // let encoded_m_back = decrypt_lwe(&lwe_out_ct, + // &lwe_sk_out.values(), &modq_op); let m_back = + // ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as + // f64).round() as u64) + // % (1u64 << logp); + // let noise = + // measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), + // &modq_op, &encoded_m); println!("Noise: + // {noise}"); assert_eq!(m, m_back, "Expected + // {m} but got {m_back}"); dbg!(m, m_back); // dbg!(encoded_m, encoded_m_back); } } diff --git a/src/main.rs b/src/main.rs index e7a11a9..f328e4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1 @@ -fn main() { - println!("Hello, world!"); -} +fn main() {} diff --git a/src/multi_party.rs b/src/multi_party.rs new file mode 100644 index 0000000..3f51775 --- /dev/null +++ b/src/multi_party.rs @@ -0,0 +1,286 @@ +use std::fmt::Debug; + +use itertools::izip; +use num_traits::Zero; + +use crate::{ + backend::{GetModulus, Modulus, VectorOps}, + ntt::Ntt, + random::{ + RandomFillGaussianInModulus, RandomFillUniformInModulus, RandomGaussianElementInModulus, + }, + utils::TryConvertFrom1, + ArithmeticOps, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, +}; + +pub(crate) fn public_key_share< + R: Row + RowMut + RowEntity, + S, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + Rng: RandomFillGaussianInModulus<[R::Element], ModOp::M>, + PRng: RandomFillUniformInModulus<[R::Element], ModOp::M>, +>( + share_out: &mut R, + s_i: &[S], + modop: &ModOp, + nttop: &NttOp, + p_rng: &mut PRng, + rng: &mut Rng, +) where + R: TryConvertFrom1<[S], ModOp::M>, +{ + let ring_size = share_out.as_ref().len(); + assert!(s_i.len() == ring_size); + + let q = modop.modulus(); + + // sample a + let mut a = { + let mut a = R::zeros(ring_size); + RandomFillUniformInModulus::random_fill(p_rng, &q, a.as_mut()); + a + }; + + // s*a + nttop.forward(a.as_mut()); + let mut s = R::try_convert_from(s_i, &q); + nttop.forward(s.as_mut()); + modop.elwise_mul_mut(s.as_mut(), a.as_ref()); + nttop.backward(s.as_mut()); + + RandomFillGaussianInModulus::random_fill(rng, &q, share_out.as_mut()); + modop.elwise_add_mut(share_out.as_mut(), s.as_ref()); // s*e + e +} + +/// Generate decryption share for LWE ciphertext `lwe_ct` with user's secret `s` +pub(crate) fn multi_party_decryption_share< + R: RowMut + RowEntity, + Mod: Modulus, + ModOp: ArithmeticOps + VectorOps + GetModulus, + Rng: RandomGaussianElementInModulus, + S, +>( + lwe_ct: &R, + s: &[S], + mod_op: &ModOp, + rng: &mut Rng, +) -> R::Element +where + R: TryConvertFrom1<[S], Mod>, + R::Element: Zero, +{ + assert!(lwe_ct.as_ref().len() == s.len() + 1); + let mut neg_s = R::try_convert_from(s, mod_op.modulus()); + mod_op.elwise_neg_mut(neg_s.as_mut()); + + // share = (\sum -s_i * a_i) + e + let mut share = R::Element::zero(); + izip!(neg_s.as_ref().iter(), lwe_ct.as_ref().iter().skip(1)).for_each(|(si, ai)| { + share = mod_op.add(&share, &mod_op.mul(si, ai)); + }); + + let e = rng.random(mod_op.modulus()); + share = mod_op.add(&share, &e); + + share +} + +/// Aggregate decryption shares for `lwe_ct` and return noisy decryption output +/// `m + e` +pub(crate) fn multi_party_aggregate_decryption_shares_and_decrypt< + R: RowMut + RowEntity, + ModOp: ArithmeticOps, +>( + lwe_ct: &R, + shares: &[R::Element], + mod_op: &ModOp, +) -> R::Element +where + R::Element: Zero, +{ + let mut sum_shares = R::Element::zero(); + shares + .iter() + .for_each(|v| sum_shares = mod_op.add(&sum_shares, v)); + mod_op.add(&lwe_ct.as_ref()[0], &sum_shares) +} + +pub(crate) fn non_interactive_rgsw_ct< + M: MatrixMut + MatrixEntity, + S, + PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, + Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>, + NttOp: Ntt, + ModOp: VectorOps + GetModulus, +>( + s: &[S], + u: &[S], + m: &[M::MatElement], + gadget_vec: &[M::MatElement], + p_rng: &mut PRng, + rng: &mut Rng, + nttop: &NttOp, + modop: &ModOp, +) -> (M, M) +where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy, +{ + assert_eq!(s.len(), u.len()); + assert_eq!(s.len(), m.len()); + let q = modop.modulus(); + let d = gadget_vec.len(); + let ring_size = s.len(); + + let mut s_poly_eval = M::R::try_convert_from(s, q); + let mut u_poly_eval = M::R::try_convert_from(u, q); + nttop.forward(s_poly_eval.as_mut()); + nttop.forward(u_poly_eval.as_mut()); + + // encryptions of a_i*u + e + \beta m + let mut enc_beta_m = M::zeros(d, ring_size); + // zero encrypition: a_i*s + e' + let mut zero_encryptions = M::zeros(d, ring_size); + + let mut scratch_space = M::R::zeros(ring_size); + + izip!( + enc_beta_m.iter_rows_mut(), + zero_encryptions.iter_rows_mut(), + gadget_vec.iter() + ) + .for_each(|(e_beta_m, e_zero, beta)| { + // sample a_i + RandomFillUniformInModulus::random_fill(p_rng, q, e_beta_m.as_mut()); + e_zero.as_mut().copy_from_slice(e_beta_m.as_ref()); + + // a_i * u + \beta m + e // + // a_i * u + nttop.forward(e_beta_m.as_mut()); + modop.elwise_mul_mut(e_beta_m.as_mut(), u_poly_eval.as_ref()); + nttop.backward(e_beta_m.as_mut()); + // sample error e + RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut()); + // a_i * u + e + modop.elwise_add_mut(e_beta_m.as_mut(), scratch_space.as_ref()); + // beta * m + modop.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta); + // a_i * u + e + \beta m + modop.elwise_add_mut(e_beta_m.as_mut(), scratch_space.as_ref()); + + // a_i * s + e // + // a_i * s + nttop.forward(e_zero.as_mut()); + modop.elwise_mul_mut(e_zero.as_mut(), s_poly_eval.as_ref()); + nttop.backward(e_zero.as_mut()); + // sample error e + RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut()); + // a_i * s + e + modop.elwise_add_mut(e_zero.as_mut(), scratch_space.as_ref()); + }); + + (enc_beta_m, zero_encryptions) +} + +pub(crate) fn non_interactive_ksk_gen< + M: MatrixMut + MatrixEntity, + S, + PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, + Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>, + NttOp: Ntt, + ModOp: VectorOps + GetModulus, +>( + s: &[S], + u: &[S], + gadget_vec: &[M::MatElement], + p_rng: &mut PRng, + rng: &mut Rng, + nttop: &NttOp, + modop: &ModOp, +) -> M +where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy + Debug, +{ + assert_eq!(s.len(), u.len()); + + let q = modop.modulus(); + let d = gadget_vec.len(); + let ring_size = s.len(); + + let mut s_poly_eval = M::R::try_convert_from(s, q); + nttop.forward(s_poly_eval.as_mut()); + let u_poly = M::R::try_convert_from(u, q); + // a_i * s + \beta u + e + let mut ksk = M::zeros(d, ring_size); + + let mut scratch_space = M::R::zeros(ring_size); + + izip!(ksk.iter_rows_mut(), gadget_vec.iter()).for_each(|(e_ksk, beta)| { + // sample a_i + RandomFillUniformInModulus::random_fill(p_rng, q, e_ksk.as_mut()); + + // a_i * s + e + beta u + nttop.forward(e_ksk.as_mut()); + modop.elwise_mul_mut(e_ksk.as_mut(), s_poly_eval.as_ref()); + nttop.backward(e_ksk.as_mut()); + // sample error e + RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut()); + // a_i * s + e + modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref()); + // \beta * u + modop.elwise_scalar_mul(scratch_space.as_mut(), u_poly.as_ref(), beta); + // a_i * s + e + \beta * u + modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref()); + }); + + ksk +} + +pub(crate) fn non_interactive_ksk_zero_encryptions_for_other_party_i< + M: MatrixMut + MatrixEntity, + S, + PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, + Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>, + NttOp: Ntt, + ModOp: VectorOps + GetModulus, +>( + s: &[S], + gadget_vec: &[M::MatElement], + p_rng: &mut PRng, + rng: &mut Rng, + nttop: &NttOp, + modop: &ModOp, +) -> M +where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy + Debug, +{ + let q = modop.modulus(); + let d = gadget_vec.len(); + let ring_size = s.len(); + + let mut s_poly_eval = M::R::try_convert_from(s, q); + nttop.forward(s_poly_eval.as_mut()); + + // a_i * s + e + let mut zero_encs = M::zeros(d, ring_size); + + let mut scratch_space = M::R::zeros(ring_size); + + izip!(zero_encs.iter_rows_mut()).for_each(|e_zero| { + // sample a_i + RandomFillUniformInModulus::random_fill(p_rng, q, e_zero.as_mut()); + + // a_i * s + e + nttop.forward(e_zero.as_mut()); + modop.elwise_mul_mut(e_zero.as_mut(), s_poly_eval.as_ref()); + nttop.backward(e_zero.as_mut()); + // sample error e + RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut()); + modop.elwise_add_mut(e_zero.as_mut(), scratch_space.as_ref()); + }); + + zero_encs +} diff --git a/src/ntt.rs b/src/ntt.rs index 320a28c..743aa32 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,11 +1,18 @@ -use itertools::Itertools; -use rand::{thread_rng, Rng, RngCore}; +use itertools::{izip, Itertools}; +use rand::{Rng, RngCore, SeedableRng}; +use rand_chacha::ChaCha8Rng; use crate::{ - backend::{ArithmeticOps, ModularOpsU64}, - utils::{mod_exponent, mod_inverse, shoup_representation_fq}, + backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus}, + utils::{mod_exponent, mod_inverse, ShoupMul}, }; +pub trait NttInit { + /// Ntt istance must be compatible across different instances with same `q` + /// and `n` + fn new(q: &M, n: usize) -> Self; +} + pub trait Ntt { type Element; fn forward_lazy(&self, v: &mut [Self::Element]); @@ -21,27 +28,50 @@ pub trait Ntt { /// and both x' and y' are \in [0, 4q) /// /// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) -pub unsafe fn forward_butterly( - x: *mut u64, - y: *mut u64, - w: &u64, - w_shoup: &u64, - q: &u64, - q_twice: &u64, -) { - debug_assert!(*x < *q * 4, "{} >= (4q){}", *x, 4 * q); - debug_assert!(*y < *q * 4, "{} >= (4q){}", *y, 4 * q); +pub fn forward_butterly_0_to_4q( + mut x: u64, + y: u64, + w: u64, + w_shoup: u64, + q: u64, + q_twice: u64, +) -> (u64, u64) { + debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q); + debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q); - if *x >= *q_twice { - *x = *x - q_twice; + if x >= q_twice { + x = x - q_twice; } - // TODO (Jay): Hot path expected. How expensive is it? - let k = ((*w_shoup as u128 * *y as u128) >> 64) as u64; - let t = w.wrapping_mul(*y).wrapping_sub(k.wrapping_mul(*q)); + let t = ShoupMul::mul(y, w, w_shoup, q); - *y = *x + q_twice - t; - *x = *x + t; + (x + t, x + q_twice - t) +} + +pub fn forward_butterly_0_to_2q( + mut x: u64, + y: u64, + w: u64, + w_shoup: u64, + q: u64, + q_twice: u64, +) -> (u64, u64) { + debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q); + debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q); + + if x >= q_twice { + x = x - q_twice; + } + + let t = ShoupMul::mul(y, w, w_shoup, q); + + let ox = x.wrapping_add(t); + let oy = x.wrapping_sub(t); + + ( + (ox).min(ox.wrapping_sub(q_twice)), + oy.min(oy.wrapping_add(q_twice)), + ) } /// Inverse butterfly routine of Inverse Number theoretic transform. Given @@ -51,27 +81,26 @@ pub unsafe fn forward_butterly( /// and both x' and y' are \in [0, 2q) /// /// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) -pub unsafe fn inverse_butterfly( - x: *mut u64, - y: *mut u64, - w_inv: &u64, - w_inv_shoup: &u64, - q: &u64, - q_twice: &u64, -) { - debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x); - debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y); +pub fn inverse_butterfly_0_to_2q( + x: u64, + y: u64, + w_inv: u64, + w_inv_shoup: u64, + q: u64, + q_twice: u64, +) -> (u64, u64) { + debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x); + debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y); - let mut x_dash = *x + *y; - if x_dash >= *q_twice { + let mut x_dash = x + y; + if x_dash >= q_twice { x_dash -= q_twice } - let t = *x + q_twice - *y; - let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path - *y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q)); + let t = x + q_twice - y; + let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q); - *x = x_dash; + (x_dash, y) } /// Number theoretic transform of vector `a` where each element can be in range @@ -79,7 +108,7 @@ pub unsafe fn inverse_butterfly( /// /// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf. pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { - debug_assert!(a.len() == psi.len()); + assert!(a.len() == psi.len()); let n = a.len(); let mut t = n; @@ -87,30 +116,67 @@ pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: let mut m = 1; while m < n { t >>= 1; - - for i in 0..m { - let j_1 = 2 * i * t; - let j_2 = j_1 + t; - - unsafe { - let w = psi.get_unchecked(m + i); - let w_shoup = psi_shoup.get_unchecked(m + i); - for j in j_1..j_2 { - let x = a.get_unchecked_mut(j) as *mut u64; - let y = a.get_unchecked_mut(j + t) as *mut u64; - forward_butterly(x, y, w, w_shoup, &q, &q_twice); + let w = &psi[m..]; + let w_shoup = &psi_shoup[m..]; + + if t == 1 { + for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) { + let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice); + a[0] = ox; + a[1] = oy; + } + } else { + for i in 0..m { + let a = &mut a[2 * i * t..(2 * (i + 1) * t)]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice); + *x = ox; + *y = oy; } } } m <<= 1; } +} + +/// Same as `ntt_lazy` with output in range [0, q) +pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { + assert!(a.len() == psi.len()); - a.iter_mut().for_each(|a0| { - if *a0 >= q_twice { - *a0 -= q_twice + let n = a.len(); + let mut t = n; + + let mut m = 1; + while m < n { + t >>= 1; + let w = &psi[m..]; + let w_shoup = &psi_shoup[m..]; + + if t == 1 { + for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) { + let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice); + // reduce from range [0, 2q) to [0, q) + a[0] = ox.min(ox.wrapping_sub(q)); + a[1] = oy.min(oy.wrapping_sub(q)); + } + } else { + for i in 0..m { + let a = &mut a[2 * i * t..(2 * (i + 1) * t)]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice); + *x = ox; + *y = oy; + } + } } - }); + + m <<= 1; + } } /// Inverse number theoretic transform of input vector `a` with each element can @@ -123,36 +189,92 @@ pub fn ntt_inv_lazy( psi_inv: &[u64], psi_inv_shoup: &[u64], n_inv: u64, + n_inv_shoup: u64, q: u64, q_twice: u64, ) { - debug_assert!(a.len() == psi_inv.len()); + assert!(a.len() == psi_inv.len()); - let mut m = a.len(); + let mut m = a.len() >> 1; let mut t = 1; - while m > 1 { - let mut j_1: usize = 0; - let h = m >> 1; - for i in 0..h { - let j_2 = j_1 + t; - unsafe { - let w_inv = psi_inv.get_unchecked(h + i); - let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); - - for j in j_1..j_2 { - let x = a.get_unchecked_mut(j) as *mut u64; - let y = a.get_unchecked_mut(j + t) as *mut u64; - inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); + + while m > 0 { + if m == 1 { + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice); + *x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q); + *y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q); + } + } else { + let w_inv = &psi_inv[m..]; + let w_inv_shoup = &psi_inv_shoup[m..]; + for i in 0..m { + let a = &mut a[2 * i * t..2 * (i + 1) * t]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); + *x = ox; + *y = oy; } } - j_1 = j_1 + 2 * t; } + t *= 2; m >>= 1; } +} + +/// Same as `ntt_inv_lazy` with output in range [0, q) +pub fn ntt_inv( + a: &mut [u64], + psi_inv: &[u64], + psi_inv_shoup: &[u64], + n_inv: u64, + n_inv_shoup: u64, + q: u64, + q_twice: u64, +) { + assert!(a.len() == psi_inv.len()); + + let mut m = a.len() >> 1; + let mut t = 1; - a.iter_mut() - .for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64); + while m > 0 { + if m == 1 { + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice); + let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q); + let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q); + *x = ox.min(ox.wrapping_sub(q)); + *y = oy.min(oy.wrapping_sub(q)); + } + } else { + let w_inv = &psi_inv[m..]; + let w_inv_shoup = &psi_inv_shoup[m..]; + for i in 0..m { + let a = &mut a[2 * i * t..2 * (i + 1) * t]; + let (left, right) = a.split_at_mut(t); + + for (x, y) in izip!(left.iter_mut(), right.iter_mut()) { + let (ox, oy) = + inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice); + *x = ox; + *y = oy; + } + } + } + + t *= 2; + m >>= 1; + } } /// Find n^{th} root of unity in field F_q, if one exists @@ -184,11 +306,13 @@ pub(crate) fn find_primitive_root(q: u64, n: u64, rng: &mut R) -> Op None } +#[derive(Debug)] pub struct NttBackendU64 { q: u64, q_twice: u64, - n: u64, + _n: u64, n_inv: u64, + n_inv_shoup: u64, psi_powers_bo: Box<[u64]>, psi_inv_powers_bo: Box<[u64]>, psi_powers_bo_shoup: Box<[u64]>, @@ -196,12 +320,11 @@ pub struct NttBackendU64 { } impl NttBackendU64 { - pub fn new(q: u64, n: usize) -> Self { + fn _new(q: u64, n: usize) -> Self { // \psi = 2n^{th} primitive root of unity in F_q - let mut rng = thread_rng(); + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) .expect("Unable to find 2n^th root of unity"); - let psi_inv = mod_inverse(psi, q); // assert!( @@ -238,11 +361,11 @@ impl NttBackendU64 { // shoup representation let psi_powers_bo_shoup = psi_powers_bo .iter() - .map(|v| shoup_representation_fq(*v, q)) + .map(|v| ShoupMul::representation(*v, q)) .collect_vec(); let psi_inv_powers_bo_shoup = psi_inv_powers_bo .iter() - .map(|v| shoup_representation_fq(*v, q)) + .map(|v| ShoupMul::representation(*v, q)) .collect_vec(); // n^{-1} \mod{q} @@ -251,8 +374,9 @@ impl NttBackendU64 { NttBackendU64 { q, q_twice: 2 * q, - n: n as u64, + _n: n as u64, n_inv, + n_inv_shoup: ShoupMul::representation(n_inv, q), psi_powers_bo: psi_powers_bo.into_boxed_slice(), psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), @@ -261,14 +385,11 @@ impl NttBackendU64 { } } -impl NttBackendU64 { - fn reduce_from_lazy(&self, a: &mut [u64]) { - let q = self.q; - a.iter_mut().for_each(|a0| { - if *a0 >= q { - *a0 = *a0 - q; - } - }); +impl> NttInit for NttBackendU64 { + fn new(q: &M, n: usize) -> Self { + // This NTT does not support native modulus + assert!(!q.is_native()); + NttBackendU64::_new(q.q().unwrap(), n) } } @@ -286,14 +407,13 @@ impl Ntt for NttBackendU64 { } fn forward(&self, v: &mut [Self::Element]) { - ntt_lazy( + ntt( v, &self.psi_powers_bo, &self.psi_powers_bo_shoup, self.q, self.q_twice, ); - self.reduce_from_lazy(v); } fn backward_lazy(&self, v: &mut [Self::Element]) { @@ -302,24 +422,26 @@ impl Ntt for NttBackendU64 { &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, + self.n_inv_shoup, self.q, self.q_twice, ) } fn backward(&self, v: &mut [Self::Element]) { - ntt_inv_lazy( + ntt_inv( v, &self.psi_inv_powers_bo, &self.psi_inv_powers_bo_shoup, self.n_inv, + self.n_inv_shoup, self.q, self.q_twice, ); - self.reduce_from_lazy(v); } } +#[cfg(test)] mod tests { use itertools::Itertools; use rand::{thread_rng, Rng}; @@ -327,7 +449,7 @@ mod tests { use super::NttBackendU64; use crate::{ - backend::{ModularOpsU64, VectorOps}, + backend::{ModInit, ModularOpsU64, VectorOps}, ntt::Ntt, utils::{generate_prime, negacyclic_mul}, }; @@ -344,29 +466,40 @@ mod tests { .collect_vec() } + fn assert_output_range(a: &[u64], max_val: u64) { + a.iter() + .for_each(|v| assert!(v <= &max_val, "{v} > {max_val}")); + } + #[test] fn native_ntt_backend_works() { // TODO(Jay): Improve tests. Add tests for different primes and ring size. - let ntt_backend = NttBackendU64::new(Q_60_BITS, N); + let ntt_backend = NttBackendU64::_new(Q_60_BITS, N); for _ in 0..K { let mut a = random_vec_in_fq(N, Q_60_BITS); let a_clone = a.clone(); ntt_backend.forward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_eq!(a, a_clone); ntt_backend.forward_lazy(&mut a); + assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1); assert_ne!(a, a_clone); ntt_backend.backward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); assert_eq!(a, a_clone); ntt_backend.forward(&mut a); + assert_output_range(a.as_ref(), Q_60_BITS - 1); ntt_backend.backward_lazy(&mut a); + assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1); // reduce a.iter_mut().for_each(|a0| { - if *a0 > Q_60_BITS { + if *a0 >= Q_60_BITS { *a0 -= *a0 - Q_60_BITS; } }); @@ -376,13 +509,13 @@ mod tests { #[test] fn native_ntt_negacylic_mul() { - let primes = [40, 50, 60] + let primes = [25, 40, 50, 60] .iter() .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap()) .collect_vec(); for p in primes.into_iter() { - let ntt_backend = NttBackendU64::new(p, N); + let ntt_backend = NttBackendU64::_new(p, N); let modulus_backend = ModularOpsU64::new(p); for _ in 0..K { let a = random_vec_in_fq(N, p); diff --git a/src/num.rs b/src/num.rs deleted file mode 100644 index 14522ea..0000000 --- a/src/num.rs +++ /dev/null @@ -1,3 +0,0 @@ -use num_traits::{Num, PrimInt, WrappingShl, WrappingShr, Zero}; - -pub trait UnsignedInteger: Zero + Num {} diff --git a/src/pbs.rs b/src/pbs.rs new file mode 100644 index 0000000..c8a8547 --- /dev/null +++ b/src/pbs.rs @@ -0,0 +1,482 @@ +use std::fmt::Display; + +use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero}; + +use crate::{ + backend::{ArithmeticOps, Modulus, ShoupMatrixFMA, VectorOps}, + decomposer::{Decomposer, RlweDecomposer}, + lwe::lwe_key_switch, + ntt::Ntt, + rgsw::{ + rlwe_auto_shoup, rlwe_by_rgsw_shoup, RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef, + RuntimeScratchMutRef, + }, + Matrix, MatrixEntity, MatrixMut, RowMut, +}; +pub(crate) trait PbsKey { + type RgswCt; + type AutoKey; + type LweKskKey; + + /// RGSW ciphertext of LWE secret elements + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::RgswCt; + /// Key for automorphism with g^k. For -g use k = 0 + fn galois_key_for_auto(&self, k: usize) -> &Self::AutoKey; + /// LWE ksk to key switch from RLWE secret to LWE secret + fn lwe_ksk(&self) -> &Self::LweKskKey; +} + +pub(crate) trait WithShoupRepr: AsRef { + type M; + fn shoup_repr(&self) -> &Self::M; +} + +pub(crate) trait PbsInfo { + /// Type of Matrix + type M: Matrix; + /// Type of Ciphertext modulus + type Modulus: Modulus::MatElement>; + /// Type of Ntt Operator for Ring polynomials + type NttOp: Ntt::MatElement>; + /// Type of Signed Decomposer + type D: Decomposer::MatElement>; + + // Although both `RlweModOp` and `LweModOp` types have same bounds, they can be + // different types. For ex, type RlweModOp may only support native modulus, + // where LweModOp may only support prime modulus, etc. + + /// Type of RLWE Modulus Operator + type RlweModOp: ArithmeticOps::MatElement> + + ShoupMatrixFMA<::R>; + /// Type of LWE Modulus Operator + type LweModOp: VectorOps::MatElement> + + ArithmeticOps::MatElement>; + + /// RLWE ciphertext modulus + fn rlwe_q(&self) -> &Self::Modulus; + /// LWE ciphertext modulus + fn lwe_q(&self) -> &Self::Modulus; + /// Blind rotation modulus. It is the modulus to which we switch for blind + /// rotation. Since blind rotation decrypts LWE ciphetext in the exponent of + /// ring polynmial (which is a ring mod 2N), `br_q <= 2N` + fn br_q(&self) -> usize; + /// Ring polynomial size `N` + fn rlwe_n(&self) -> usize; + /// LWE dimension `n` + fn lwe_n(&self) -> usize; + /// Embedding fator for ring X^{q}+1 inside + fn embedding_factor(&self) -> usize; + /// Window size parameter LKMC++ blind rotaiton + fn w(&self) -> usize; + /// generator `g` for group Z^*_{br_q} + fn g(&self) -> isize; + /// LWE key switching decomposer + fn lwe_decomposer(&self) -> &Self::D; + /// RLWE x RGSW decoposer + fn rlwe_rgsw_decomposer(&self) -> &(Self::D, Self::D); + /// RLWE auto decomposer + fn auto_decomposer(&self) -> &Self::D; + + /// LWE modulus operator + fn modop_lweq(&self) -> &Self::LweModOp; + /// RLWE modulus operator + fn modop_rlweq(&self) -> &Self::RlweModOp; + + /// Ntt operators + fn nttop_rlweq(&self) -> &Self::NttOp; + + /// Maps a \in Z^*_{br_q} to discrete log k, with generator g (i.e. g^k = + /// a). Returned vector is of size q that stores dlog of `a` at `vec[a]`. + /// + /// For any `a`, if k is s.t. `a = g^{k} % br_q`, then `k` is expressed as + /// k. If `k` is s.t `a = -g^{k} % br_q`, then `k` is expressed as + /// k=k+q/4 + fn g_k_dlog_map(&self) -> &[usize]; + /// Returns auto map and index vector for auto element g^k. For auto element + /// -g set k = 0. + fn rlwe_auto_map(&self, k: usize) -> &(Vec, Vec); +} + +/// - Mod down +/// - key switching +/// - mod down +/// - blind rotate +pub(crate) fn pbs< + M: MatrixMut + MatrixEntity, + MShoup: WithShoupRepr, + P: PbsInfo, + K: PbsKey, +>( + pbs_info: &P, + test_vec: &M::R, + lwe_in: &mut M::R, + pbs_key: &K, + scratch_lwe_vec: &mut M::R, + scratch_blind_rotate_matrix: &mut M, +) where + ::R: RowMut, + M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display, +{ + let rlwe_q = pbs_info.rlwe_q(); + let lwe_q = pbs_info.lwe_q(); + let br_q = pbs_info.br_q(); + let rlwe_qf64 = rlwe_q.q_as_f64().unwrap(); + let lwe_qf64 = lwe_q.q_as_f64().unwrap(); + let br_qf64 = br_q.to_f64().unwrap(); + let rlwe_n = pbs_info.rlwe_n(); + + // moddown Q -> Q_ks + lwe_in.as_mut().iter_mut().for_each(|v| { + *v = + M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap() + }); + + // key switch RLWE secret to LWE secret + // let now = std::time::Instant::now(); + scratch_lwe_vec.as_mut().fill(M::MatElement::zero()); + lwe_key_switch( + scratch_lwe_vec, + lwe_in, + pbs_key.lwe_ksk(), + pbs_info.modop_lweq(), + pbs_info.lwe_decomposer(), + ); + // println!("Time: {:?}", now.elapsed()); + + // odd moddown Q_ks -> q + let g_k_dlog_map = pbs_info.g_k_dlog_map(); + let mut g_k_si = vec![vec![]; br_q >> 1]; + scratch_lwe_vec + .as_ref() + .iter() + .skip(1) + .enumerate() + .for_each(|(index, v)| { + let odd_v = mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64); + // dlog `k` for `odd_v` is stored as `k` if odd_v = +g^{k}. If odd_v = -g^{k}, + // then `k` is stored as `q/4 + k`. + let k = g_k_dlog_map[odd_v]; + // assert!(k != 0); + g_k_si[k].push(index); + }); + + // handle b and set trivial test RLWE + let g = pbs_info.g() as usize; + let g_times_b = (g * mod_switch_odd( + scratch_lwe_vec.as_ref()[0].to_f64().unwrap(), + lwe_qf64, + br_qf64, + )) % (br_q); + // v = (v(X) * X^{g*b}) mod X^{q/2}+1 + let br_qby2 = br_q >> 1; + let mut gb_monomial_sign = true; + let mut gb_monomial_exp = g_times_b; + // X^{g*b} mod X^{q/2}+1 + if gb_monomial_exp > br_qby2 { + gb_monomial_exp -= br_qby2; + gb_monomial_sign = false + } + // monomial mul + let mut trivial_rlwe_test_poly = M::zeros(2, rlwe_n); + if pbs_info.embedding_factor() == 1 { + monomial_mul( + test_vec.as_ref(), + trivial_rlwe_test_poly.get_row_mut(1).as_mut(), + gb_monomial_exp, + gb_monomial_sign, + br_qby2, + pbs_info.modop_rlweq(), + ); + } else { + // use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This + // works because q/2 <= N (where N is lwe_in LWE dimension) always. + monomial_mul( + test_vec.as_ref(), + &mut lwe_in.as_mut()[..br_qby2], + gb_monomial_exp, + gb_monomial_sign, + br_qby2, + pbs_info.modop_rlweq(), + ); + + // emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1 + let embed_factor = pbs_info.embedding_factor(); + let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1); + lwe_in.as_ref()[..br_qby2] + .iter() + .enumerate() + .for_each(|(index, v)| { + partb_trivial_rlwe[embed_factor * index] = *v; + }); + } + + // let now = std::time::Instant::now(); + // blind rotate + blind_rotation( + &mut trivial_rlwe_test_poly, + scratch_blind_rotate_matrix, + pbs_info.g(), + pbs_info.w(), + br_q, + &g_k_si, + pbs_info.rlwe_rgsw_decomposer(), + pbs_info.auto_decomposer(), + pbs_info.nttop_rlweq(), + pbs_info.modop_rlweq(), + pbs_info, + pbs_key, + ); + // println!("Blind rotation time: {:?}", now.elapsed()); + + // sample extract + sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0); +} + +/// LMKCY+ Blind rotation +/// +/// - gk_to_si: Contains LWE secret index `i` in array of secret indices at k^th +/// index if a_i = g^k if k < q/4 or a_i = -g^k if k > q/4. [g^0, ..., +/// g^{q/2-1}, -g^0, -g^1, .., -g^{q/2-1}] +fn blind_rotation< + Mmut: MatrixMut, + RlweD: RlweDecomposer, + AutoD: Decomposer, + NttOp: Ntt, + ModOp: ArithmeticOps + ShoupMatrixFMA, + MShoup: WithShoupRepr, + K: PbsKey, + P: PbsInfo, +>( + trivial_rlwe_test_poly: &mut Mmut, + scratch_matrix: &mut Mmut, + _g: isize, + w: usize, + q: usize, + gk_to_si: &[Vec], + rlwe_rgsw_decomposer: &RlweD, + auto_decomposer: &AutoD, + ntt_op: &NttOp, + mod_op: &ModOp, + parameters: &P, + pbs_key: &K, +) where + ::R: RowMut, + Mmut::MatElement: Copy + Zero, +{ + let mut is_trivial = true; + let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut()); + let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut()); + let d_a = rlwe_rgsw_decomposer.a().decomposition_count().0; + let d_b = rlwe_rgsw_decomposer.b().decomposition_count().0; + let d_auto = auto_decomposer.decomposition_count().0; + + let q_by_4 = q >> 2; + // let mut count = 0; + // -(g^k) + let mut v = 0; + for i in (1..q_by_4).rev() { + // dbg!(q_by_4 + i); + let s_indices = &gk_to_si[q_by_4 + i]; + + s_indices.iter().for_each(|s_index| { + // let new = std::time::Instant::now(); + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, + rlwe_rgsw_decomposer, + ntt_op, + mod_op, + is_trivial, + ); + is_trivial = false; + // println!("Rlwe x Rgsw time: {:?}", new.elapsed()); + }); + v += 1; + + if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 { + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); + + // let now = std::time::Instant::now(); + let auto_key = pbs_key.galois_key_for_auto(v); + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, + &auto_map_index, + &auto_map_sign, + mod_op, + ntt_op, + auto_decomposer, + is_trivial, + ); + // println!("Auto time: {:?}", now.elapsed()); + // count += 1; + + v = 0; + } + } + + // -(g^0) + { + gk_to_si[q_by_4].iter().for_each(|s_index| { + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, + rlwe_rgsw_decomposer, + ntt_op, + mod_op, + is_trivial, + ); + is_trivial = false; + }); + + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(0); + let auto_key = pbs_key.galois_key_for_auto(0); + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, + &auto_map_index, + &auto_map_sign, + mod_op, + ntt_op, + auto_decomposer, + is_trivial, + ); + // count += 1; + } + + // +(g^k) + let mut v = 0; + for i in (1..q_by_4).rev() { + let s_indices = &gk_to_si[i]; + s_indices.iter().for_each(|s_index| { + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, + rlwe_rgsw_decomposer, + ntt_op, + mod_op, + is_trivial, + ); + is_trivial = false; + }); + v += 1; + + if gk_to_si[i - 1].len() != 0 || v == w || i == 1 { + let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v); + let auto_key = pbs_key.galois_key_for_auto(v); + rlwe_auto_shoup( + &mut rlwe, + &RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto), + &RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto), + &mut scratch_matrix, + &auto_map_index, + &auto_map_sign, + mod_op, + ntt_op, + auto_decomposer, + is_trivial, + ); + v = 0; + + // count += 1; + } + } + + // +(g^0) + gk_to_si[0].iter().for_each(|s_index| { + let ct = pbs_key.rgsw_ct_lwe_si(*s_index); + rlwe_by_rgsw_shoup( + &mut rlwe, + &RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b), + &RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b), + &mut scratch_matrix, + rlwe_rgsw_decomposer, + ntt_op, + mod_op, + is_trivial, + ); + is_trivial = false; + }); + // println!("Auto count: {count}"); +} + +fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { + let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap(); + //TODO(Jay): check correctness of this + odd_v + ((odd_v & 1) ^ 1) +} + +// TODO(Jay): Add tests for sample extract +pub(crate) fn sample_extract>( + lwe_out: &mut M::R, + rlwe_in: &M, + mod_op: &ModOp, + index: usize, +) where + ::R: RowMut, + M::MatElement: Copy, +{ + let ring_size = rlwe_in.dimension().1; + assert!(ring_size + 1 == lwe_out.as_ref().len()); + + // index..=0 + let to = &mut lwe_out.as_mut()[1..]; + let from = rlwe_in.get_row_slice(0); + for i in 0..index + 1 { + to[i] = from[index - i]; + } + + // -(N..index) + for i in index + 1..ring_size { + to[i] = mod_op.neg(&from[ring_size + index - i]); + } + + // set b + lwe_out.as_mut()[0] = *rlwe_in.get(1, index); +} + +/// Monomial multiplication (p(X)*X^{mon_exp}) +/// +/// - p_out: Output is written to p_out and independent of values in p_out +fn monomial_mul>( + p_in: &[El], + p_out: &mut [El], + mon_exp: usize, + mon_sign: bool, + ring_size: usize, + mod_op: &ModOp, +) where + El: Copy, +{ + debug_assert!(p_in.as_ref().len() == ring_size); + debug_assert!(p_in.as_ref().len() == p_out.as_ref().len()); + debug_assert!(mon_exp < ring_size); + + p_in.as_ref().iter().enumerate().for_each(|(index, v)| { + let mut to_index = index + mon_exp; + let mut to_sign = mon_sign; + if to_index >= ring_size { + to_index = to_index - ring_size; + to_sign = !to_sign; + } + + if !to_sign { + p_out.as_mut()[to_index] = mod_op.neg(v); + } else { + p_out.as_mut()[to_index] = *v; + } + }); +} diff --git a/src/random.rs b/src/random.rs index 397585e..1745a94 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,33 +1,69 @@ use std::cell::RefCell; use itertools::izip; -use rand::{distributions::Uniform, thread_rng, CryptoRng, Rng, RngCore, SeedableRng}; +use num_traits::{FromPrimitive, PrimInt, Zero}; +use rand::{distributions::Uniform, Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; -use rand_distr::Distribution; +use rand_distr::{uniform::SampleUniform, Distribution}; -use crate::utils::WithLocal; +use crate::{backend::Modulus, utils::WithLocal}; thread_local! { pub(crate) static DEFAULT_RNG: RefCell = RefCell::new(DefaultSecureRng::new()); } -pub trait RandomGaussianDist +pub trait NewWithSeed { + type Seed; + fn new_with_seed(seed: Self::Seed) -> Self; +} + +pub trait RandomElementInModulus { + /// Sample Random element of type T in range [0, modulus) + fn random(&mut self, modulus: &M) -> T; +} + +pub trait RandomGaussianElementInModulus { + /// Sample Random gaussian element from \mu = 0.0 and \sigma = 3.19. Sampled + /// element is converted to signed representation in modulus. + fn random(&mut self, modulus: &M) -> T; +} + +pub trait RandomFill +where + M: ?Sized, +{ + /// Fill container with random elements of type of its elements + fn random_fill(&mut self, container: &mut M); +} + +pub trait RandomFillGaussian +where + M: ?Sized, +{ + /// Fill container with random elements sampled from normal distribtuion + /// with \mu = 0.0 and \sigma = 3.19. + fn random_fill(&mut self, container: &mut M); +} + +pub trait RandomFillUniformInModulus where M: ?Sized, { - type Parameters: ?Sized; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M); + /// Fill container with random elements in range [0, modulus) + fn random_fill(&mut self, modulus: &P, container: &mut M); } -pub trait RandomUniformDist +pub trait RandomFillGaussianInModulus where M: ?Sized, { - type Parameters: ?Sized; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M); + /// Fill container with gaussian elements sampled from normal distribution + /// with \mu = 0.0 and \sigma = 3.19. Elements are converted to signed + /// represented in the modulus. + fn random_fill(&mut self, modulus: &P, container: &mut M); } -pub(crate) struct DefaultSecureRng { +pub struct DefaultSecureRng { rng: ChaCha8Rng, } @@ -41,27 +77,30 @@ impl DefaultSecureRng { let rng = ChaCha8Rng::from_entropy(); DefaultSecureRng { rng } } -} -impl RandomUniformDist for DefaultSecureRng { - type Parameters = usize; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut usize) { - *container = self.rng.gen_range(0..*parameters); + pub fn fill_bytes(&mut self, a: &mut [u8; 32]) { + self.rng.fill_bytes(a); } } -impl RandomUniformDist<[u8]> for DefaultSecureRng { - type Parameters = u8; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u8]) { - self.rng.fill_bytes(container); +impl NewWithSeed for DefaultSecureRng { + type Seed = ::Seed; + fn new_with_seed(seed: Self::Seed) -> Self { + DefaultSecureRng::new_seeded(seed) } } -impl RandomUniformDist<[u32]> for DefaultSecureRng { - type Parameters = u32; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) { +impl RandomFillUniformInModulus<[T], C> for DefaultSecureRng +where + T: PrimInt + SampleUniform, + C: Modulus, +{ + fn random_fill(&mut self, modulus: &C, container: &mut [T]) { izip!( - (&mut self.rng).sample_iter(Uniform::new(0, parameters)), + (&mut self.rng).sample_iter(Uniform::new_inclusive( + T::zero(), + modulus.largest_unsigned_value() + )), container.iter_mut() ) .for_each(|(from, to)| { @@ -70,99 +109,90 @@ impl RandomUniformDist<[u32]> for DefaultSecureRng { } } -impl RandomUniformDist<[u64]> for DefaultSecureRng { - type Parameters = u64; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { +impl RandomFillGaussianInModulus<[T], C> for DefaultSecureRng +where + T: PrimInt, + C: Modulus, +{ + fn random_fill(&mut self, modulus: &C, container: &mut [T]) { izip!( - (&mut self.rng).sample_iter(Uniform::new(0, parameters)), + rand_distr::Normal::new(0.0, 3.19f64) + .unwrap() + .sample_iter(&mut self.rng), container.iter_mut() ) .for_each(|(from, to)| { - *to = from; + *to = modulus.map_element_from_f64(from); }); } } -impl RandomGaussianDist for DefaultSecureRng { - type Parameters = u64; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) { - let o = rand_distr::Normal::new(0.0, 3.2f64) - .unwrap() - .sample(&mut self.rng) - .round(); - - // let o = 0.0f64; - - let is_neg = o.is_sign_negative() && o != 0.0; - if is_neg { - *container = parameters - (o.abs() as u64); - } else { - *container = o as u64; - } - } -} - -impl RandomGaussianDist for DefaultSecureRng { - type Parameters = u32; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u32) { - let o = rand_distr::Normal::new(0.0, 3.2f32) - .unwrap() - .sample(&mut self.rng) - .round(); - - // let o = 0.0f32; - let is_neg = o.is_sign_negative() && o != 0.0; - - if is_neg { - *container = parameters - (o.abs() as u32); - } else { - *container = o as u32; - } +impl RandomFill<[T]> for DefaultSecureRng +where + T: PrimInt + SampleUniform, +{ + fn random_fill(&mut self, container: &mut [T]) { + izip!( + (&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())), + container.iter_mut() + ) + .for_each(|(from, to)| { + *to = from; + }); } } -impl RandomGaussianDist<[u64]> for DefaultSecureRng { - type Parameters = u64; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { +impl RandomFillGaussian<[T]> for DefaultSecureRng +where + T: FromPrimitive, +{ + fn random_fill(&mut self, container: &mut [T]) { izip!( - rand_distr::Normal::new(0.0, 3.2f64) + rand_distr::Normal::new(0.0, 3.19f64) .unwrap() .sample_iter(&mut self.rng), container.iter_mut() ) - .for_each(|(oi, v)| { - let oi = oi.round(); - let is_neg = oi.is_sign_negative() && oi != 0.0; - if is_neg { - *v = parameters - (oi.abs() as u64); - } else { - *v = oi as u64; - } + .for_each(|(from, to)| { + *to = T::from_f64(from).unwrap(); }); } } -impl RandomGaussianDist<[u32]> for DefaultSecureRng { - type Parameters = u32; - fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) { +impl RandomFill<[T; 32]> for DefaultSecureRng +where + T: PrimInt + SampleUniform, +{ + fn random_fill(&mut self, container: &mut [T; 32]) { izip!( - rand_distr::Normal::new(0.0, 3.2f32) - .unwrap() - .sample_iter(&mut self.rng), + (&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())), container.iter_mut() ) - .for_each(|(oi, v)| { - let oi = oi.round(); - let is_neg = oi.is_sign_negative() && oi != 0.0; - if is_neg { - *v = parameters - (oi.abs() as u32); - } else { - *v = oi as u32; - } + .for_each(|(from, to)| { + *to = from; }); } } +impl RandomElementInModulus for DefaultSecureRng +where + T: Zero + SampleUniform, +{ + fn random(&mut self, modulus: &T) -> T { + Uniform::new(T::zero(), modulus).sample(&mut self.rng) + } +} + +impl> RandomGaussianElementInModulus for DefaultSecureRng { + fn random(&mut self, modulus: &M) -> T { + modulus.map_element_from_f64( + rand_distr::Normal::new(0.0, 3.19f64) + .unwrap() + .sample(&mut self.rng), + ) + } +} + impl WithLocal for DefaultSecureRng { fn with_local(func: F) -> R where @@ -177,4 +207,11 @@ impl WithLocal for DefaultSecureRng { { DEFAULT_RNG.with_borrow_mut(|r| func(r)) } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + DEFAULT_RNG.with_borrow_mut(|r| func(r)) + } } diff --git a/src/rgsw.rs b/src/rgsw.rs deleted file mode 100644 index 72a30cf..0000000 --- a/src/rgsw.rs +++ /dev/null @@ -1,466 +0,0 @@ -use itertools::izip; - -use crate::{ - backend::VectorOps, - decomposer::{self, Decomposer}, - ntt::Ntt, - random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, - utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, - Matrix, MatrixEntity, MatrixMut, RowMut, Secret, -}; - -struct RlweSecret { - values: Vec, -} - -impl Secret for RlweSecret { - type Element = i32; - fn values(&self) -> &[Self::Element] { - &self.values - } -} - -impl RlweSecret { - fn random(hw: usize, n: usize) -> RlweSecret { - DefaultSecureRng::with_local_mut(|rng| { - let mut out = vec![0i32; n]; - fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); - - RlweSecret { values: out } - }) - } -} - -/// Encrypts message m as a RGSW ciphertext. -/// -/// - m_eval: is `m` is evaluation domain -/// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 4, -/// ring_size). The matrix has the following structure [RLWE'_A(-sm) || -/// RLWE'_B(-sm) || RLWE'_A(m) || RLWE'_B(m)]^T -fn encrypt_rgsw< - Mmut: MatrixMut + MatrixEntity, - M: Matrix + Clone, - S: Secret, - R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> - + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, - ModOp: VectorOps, - NttOp: Ntt, ->( - out_rgsw: &mut Mmut, - m_eval: &M, - gadget_vector: &[Mmut::MatElement], - s: &S, - mod_op: &ModOp, - ntt_op: &NttOp, - rng: &mut R, -) where - ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, -{ - let d = gadget_vector.len(); - let q = mod_op.modulus(); - let ring_size = s.values().len(); - assert!(out_rgsw.dimension() == (d * 4, ring_size)); - assert!(m_eval.dimension() == (1, ring_size)); - - // RLWE(-sm), RLWE(-sm) - let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row(d * 2); - - let mut s_eval = Mmut::try_convert_from(s.values(), &q); - ntt_op.forward(s_eval.get_row_mut(0).as_mut()); - - let mut scratch_space = Mmut::zeros(1, ring_size); - - // RLWE'(-sm) - let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d); - izip!( - a_rlwe_dash_nsm.iter_mut(), - b_rlwe_dash_nsm.iter_mut(), - gadget_vector.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // Sample a_i and transform to evaluation domain - RandomUniformDist::random_fill(rng, &q, ai.as_mut()); - ntt_op.forward(ai.as_mut()); - - // a_i * s - mod_op.elwise_mul( - scratch_space.get_row_mut(0), - ai.as_ref(), - s_eval.get_row_slice(0), - ); - // b_i = e_i + a_i * s - RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); - ntt_op.forward(bi.as_mut()); - mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); - - // a_i + \beta_i * m - mod_op.elwise_scalar_mul( - scratch_space.get_row_mut(0), - m_eval.get_row_slice(0), - beta_i, - ); - mod_op.elwise_add_mut(ai.as_mut(), scratch_space.get_row_slice(0)); - }); - - // RLWE(m) - let (a_rlwe_dash_m, b_rlwe_dash_m) = rlwe_dash_m.split_at_mut(d); - izip!( - a_rlwe_dash_m.iter_mut(), - b_rlwe_dash_m.iter_mut(), - gadget_vector.iter() - ) - .for_each(|(ai, bi, beta_i)| { - // Sample e_i and transform to evaluation domain - RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); - ntt_op.forward(bi.as_mut()); - - // beta_i * m - mod_op.elwise_scalar_mul( - scratch_space.get_row_mut(0), - m_eval.get_row_slice(0), - beta_i, - ); - // e_i + beta_i * m - mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); - - // Sample a_i and transform to evaluation domain - RandomUniformDist::random_fill(rng, &q, ai.as_mut()); - ntt_op.forward(ai.as_mut()); - - // ai * s - mod_op.elwise_mul( - scratch_space.get_row_mut(0), - ai.as_ref(), - s_eval.get_row_slice(0), - ); - - // b_i = a_i*s + e_i + beta_i*m - mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); - }); -} - -/// Returns RLWE(mm') = RLWE(m) x RGSW(m') -/// -/// - rgsw_in: RGSW(m') in evaluation domain -/// - rlwe_in_decomposed: decomposed RLWE(m) in evaluation domain -/// - rlwe_out: returned RLWE(mm') in evaluation domain -fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< - Mmut: MatrixMut + MatrixEntity, - M: Matrix + Clone, - ModOp: VectorOps, ->( - rgsw_in: &M, - rlwe_in_decomposed_eval: &Mmut, - rlwe_out_eval: &mut Mmut, - mod_op: &ModOp, -) where - ::R: RowMut, -{ - let ring_size = rgsw_in.dimension().1; - let d_rgsw = rgsw_in.dimension().0 / 4; - assert!(rlwe_in_decomposed_eval.dimension() == (2 * d_rgsw, ring_size)); - assert!(rlwe_out_eval.dimension() == (2, ring_size)); - - let (a_rlwe_out, b_rlwe_out) = rlwe_out_eval.split_at_row(1); - - // a * RLWE'(-sm) - let a_rlwe_dash_nsm = rgsw_in.iter_rows().take(d_rgsw); - let b_rlwe_dash_nsm = rgsw_in.iter_rows().skip(d_rgsw).take(d_rgsw); - izip!( - rlwe_in_decomposed_eval.iter_rows().take(d_rgsw), - a_rlwe_dash_nsm - ) - .for_each(|(a, b)| { - mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); - }); - izip!( - rlwe_in_decomposed_eval.iter_rows().take(d_rgsw), - b_rlwe_dash_nsm - ) - .for_each(|(a, b)| { - mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); - }); - - // b * RLWE'(m) - let a_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 2).take(d_rgsw); - let b_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 3); - izip!( - rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw), - a_rlwe_dash_m - ) - .for_each(|(a, b)| { - mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); - }); - izip!( - rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw), - b_rlwe_dash_m - ) - .for_each(|(a, b)| { - mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); - }); -} - -fn decompose_rlwe< - M: Matrix + Clone, - Mmut: MatrixMut + MatrixEntity, - D: Decomposer, ->( - rlwe_in: &M, - decomposer: &D, - rlwe_in_decomposed: &mut Mmut, -) where - M::MatElement: Copy, - ::R: RowMut, -{ - let d_rgsw = decomposer.d(); - let ring_size = rlwe_in.dimension().1; - assert!(rlwe_in_decomposed.dimension() == (2 * d_rgsw, ring_size)); - - // Decompose rlwe_in - for ri in 0..ring_size { - // ai - let ai_decomposed = decomposer.decompose(rlwe_in.get(0, ri)); - for j in 0..d_rgsw { - rlwe_in_decomposed.set(j, ri, ai_decomposed[j]); - } - - // bi - let bi_decomposed = decomposer.decompose(rlwe_in.get(1, ri)); - for j in 0..d_rgsw { - rlwe_in_decomposed.set(j + d_rgsw, ri, bi_decomposed[j]); - } - } -} - -/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1) -/// -/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain -/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain -/// - rlwe_out: is output RLWE(m0m1) with polynomials in coefficient domain -/// - rlwe_in_decomposed: is a matrix of dimension (d_rgsw * 2, ring_size) used -/// as scratch space to store decomposed RLWE(m0) -fn rlwe_by_rgsw< - M: Matrix + Clone, - Mmut: MatrixMut + MatrixEntity, - D: Decomposer, - ModOp: VectorOps, - NttOp: Ntt, ->( - rlwe_in: &M, - rgsw_in: &M, - rlwe_out: &mut Mmut, - rlwe_in_decomposed: &mut Mmut, - decomposer: &D, - ntt_op: &NttOp, - mod_op: &ModOp, -) where - M::MatElement: Copy, - ::R: RowMut, -{ - decompose_rlwe(rlwe_in, decomposer, rlwe_in_decomposed); - - // transform rlwe_in decomposed to evaluation domain - rlwe_in_decomposed - .iter_rows_mut() - .for_each(|r| ntt_op.forward(r.as_mut())); - - // decomposed RLWE x RGSW - rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain( - rgsw_in, - rlwe_in_decomposed, - rlwe_out, - mod_op, - ); - - // transform rlwe_out to coefficient domain - rlwe_out - .iter_rows_mut() - .for_each(|r| ntt_op.backward(r.as_mut())); -} - -/// Encrypt polynomial m(X) as RLWE ciphertext. -/// -/// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE -/// ciphertext is a matirx with first row consiting polynomial `a` and the -/// second rows consting polynomial `b` -fn encrypt_rlwe< - Mmut: Matrix + MatrixMut + Clone, - ModOp: VectorOps, - NttOp: Ntt, - S: Secret, - R: RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement> - + RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, ->( - m: &Mmut, - rlwe_out: &mut Mmut, - s: &S, - mod_op: &ModOp, - ntt_op: &NttOp, - rng: &mut R, -) where - ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, -{ - let ring_size = s.values().len(); - assert!(rlwe_out.dimension() == (2, ring_size)); - assert!(m.dimension() == (1, ring_size)); - - let q = mod_op.modulus(); - - // sample a - RandomUniformDist::random_fill(rng, &q, rlwe_out.get_row_mut(0)); - - // s * a - let mut sa = Mmut::try_convert_from(s.values(), &q); - ntt_op.forward(sa.get_row_mut(0)); - ntt_op.forward(rlwe_out.get_row_mut(0)); - mod_op.elwise_mul_mut(sa.get_row_mut(0), rlwe_out.get_row_slice(0)); - ntt_op.backward(rlwe_out.get_row_mut(0)); - ntt_op.backward(sa.get_row_mut(0)); - - // sample e - RandomGaussianDist::random_fill(rng, &q, rlwe_out.get_row_mut(1)); - mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m.get_row_slice(0)); - mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), sa.get_row_slice(0)); -} - -/// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m -/// -/// - rlwe_ct: input degree 1 ciphertext RLWE(m). -fn decrypt_rlwe< - Mmut: MatrixMut + Clone, - M: Matrix, - ModOp: VectorOps, - NttOp: Ntt, - S: Secret, ->( - rlwe_ct: &M, - s: &S, - m_out: &mut Mmut, - ntt_op: &NttOp, - mod_op: &ModOp, -) where - ::R: RowMut, - Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, - Mmut::MatElement: Copy, -{ - let ring_size = s.values().len(); - assert!(rlwe_ct.dimension() == (2, ring_size)); - assert!(m_out.dimension() == (1, ring_size)); - - // transform a to evluation form - m_out - .get_row_mut(0) - .copy_from_slice(rlwe_ct.get_row_slice(0)); - ntt_op.forward(m_out.get_row_mut(0)); - - // -s*a - let mut s = Mmut::try_convert_from(&s.values(), &mod_op.modulus()); - ntt_op.forward(s.get_row_mut(0)); - mod_op.elwise_mul_mut(m_out.get_row_mut(0), s.get_row_slice(0)); - mod_op.elwise_neg_mut(m_out.get_row_mut(0)); - ntt_op.backward(m_out.get_row_mut(0)); - - // m+e = b - s*a - mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1)); -} - -#[cfg(test)] -mod tests { - use std::vec; - - use itertools::Itertools; - use rand::{thread_rng, Rng}; - - use crate::{ - backend::ModularOpsU64, - decomposer::{gadget_vector, DefaultDecomposer}, - ntt::{self, Ntt, NttBackendU64}, - random::{DefaultSecureRng, RandomUniformDist}, - utils::{generate_prime, negacyclic_mul}, - }; - - use super::{decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, rlwe_by_rgsw, RlweSecret}; - - #[test] - fn rlwe_by_rgsw_works() { - let logq = 50; - let logp = 3; - let ring_size = 1 << 10; - let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); - let p = 1u64 << logp; - let d_rgsw = 10; - let logb = 5; - - let mut rng = DefaultSecureRng::new(); - - let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); - - let mut m0 = vec![0u64; ring_size as usize]; - RandomUniformDist::<[u64]>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); - let mut m1 = vec![0u64; ring_size as usize]; - m1[thread_rng().gen_range(0..ring_size) as usize] = 1; - - let ntt_op = NttBackendU64::new(q, ring_size as usize); - let mod_op = ModularOpsU64::new(q); - - // Encrypt m1 as RGSW(m1) - let mut rgsw_ct = vec![vec![0u64; ring_size as usize]; d_rgsw * 4]; - let gadget_vector = gadget_vector(logq, logb, d_rgsw); - let mut m1_eval = m1.clone(); - ntt_op.forward(&mut m1_eval); - encrypt_rgsw( - &mut rgsw_ct, - &vec![m1_eval], - &gadget_vector, - &s, - &mod_op, - &ntt_op, - &mut rng, - ); - // println!("RGSW(m1): {:?}", &rgsw_ct); - - // Encrypt m0 as RLWE(m0) - let mut rlwe_in_ct = vec![vec![0u64; ring_size as usize]; 2]; - let encoded_m = m0 - .iter() - .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) - .collect_vec(); - encrypt_rlwe( - &vec![encoded_m.clone()], - &mut rlwe_in_ct, - &s, - &mod_op, - &ntt_op, - &mut rng, - ); - - // RLWE(m0m1) = RLWE(m0) x RGSW(m1) - let mut rlwe_out_ct = vec![vec![0u64; ring_size as usize]; 2]; - let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw * 2]; - let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); - rlwe_by_rgsw( - &rlwe_in_ct, - &rgsw_ct, - &mut rlwe_out_ct, - &mut scratch_space, - &decomposer, - &ntt_op, - &mod_op, - ); - - // Decrypt RLWE(m0m1) - let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]]; - decrypt_rlwe(&rlwe_out_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op); - let m0m1_back = encoded_m0m1_back[0] - .iter() - .map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p) - .collect_vec(); - - let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p; - let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p); - assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back); - // dbg!(&m0m1_back, m0m1, q); - } -} diff --git a/src/rgsw/keygen.rs b/src/rgsw/keygen.rs new file mode 100644 index 0000000..777ee9a --- /dev/null +++ b/src/rgsw/keygen.rs @@ -0,0 +1,677 @@ +use std::{fmt::Debug, ops::Sub}; + +use itertools::izip; +use num_traits::{PrimInt, Signed, ToPrimitive, Zero}; + +use crate::{ + backend::{ArithmeticOps, GetModulus, Modulus, VectorOps}, + ntt::Ntt, + random::{ + RandomElementInModulus, RandomFill, RandomFillGaussianInModulus, RandomFillUniformInModulus, + }, + utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1}, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, +}; + +pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { + assert!(k & 1 == 1, "Auto {k} must be odd"); + + let k = if k < 0 { + // k is -ve, return k%(2*N) + (2 * ring_size) - (k.abs() as usize % (2 * ring_size)) + } else { + k as usize + }; + let (auto_map_index, auto_sign_index): (Vec, Vec) = (0..ring_size) + .into_iter() + .map(|i| { + let mut to_index = (i * k) % (2 * ring_size); + let mut sign = true; + + // wrap around. false implies negative + if to_index >= ring_size { + to_index = to_index - ring_size; + sign = false; + } + + (to_index, sign) + }) + .unzip(); + (auto_map_index, auto_sign_index) +} + +/// Returns RGSW(m) +/// +/// RGSW = [RLWE'(-sm) || RLWE(m)] = [RLWE'_A(-sm), RLWE'_B(-sm), RLWE'_A(m), +/// RLWE'_B(m)] +/// +/// RGSW(m1) ciphertext is used for RLWE(m0) x RGSW(m1) multiplication. +/// Let RLWE(m) = [a, b] where b = as + e + m0. +/// For RLWExRGSW we calculate: +/// (\sum signed_decompose(a)[i] x RLWE(-s \beta^i' m1)) +/// + (\sum signed_decompose(b)[i'] x RLWE(\beta^i' m1)) +/// = RLWE(m0m1) +/// We denote decomposer count for signed_decompose(a)[i] with d_a and +/// corresponding gadget vector with `gadget_a`. We denote decomposer count for +/// signed_decompose(b)[i] with d_b and corresponding gadget vector with +/// `gadget_b` +/// +/// In secret key RGSW encrypton RLWE'_A(m) can be seeded. Hence, we seed it +/// using the `p_rng` passed and the retured RGSW ciphertext has d_a * 2 + d_b +/// rows +/// +/// - s: is the secret key +/// - m: message to encrypt +/// - gadget_a: Gadget vector for RLWE'(-sm) +/// - gadget_b: Gadget vector for RLWE'(m) +/// - p_rng: Seeded psuedo random generator used to sample RLWE'_A(m). +pub(crate) fn secret_key_encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + S, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> + + RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m: &[Mmut::MatElement], + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], + s: &[S], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut + RowEntity + TryConvertFrom1<[S], ModOp::M> + Debug, + Mmut::MatElement: Copy + Debug, +{ + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); + let q = mod_op.modulus(); + let ring_size = s.len(); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b, ring_size)); + assert!(m.as_ref().len() == ring_size); + + // RLWE(-sm), RLWE(m) + let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); + + let mut s_eval = Mmut::R::try_convert_from(s, &q); + ntt_op.forward(s_eval.as_mut()); + + let mut scratch_space = Mmut::R::zeros(ring_size); + + // RLWE'(-sm) + let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d_a); + izip!( + a_rlwe_dash_nsm.iter_mut(), + b_rlwe_dash_nsm.iter_mut(), + gadget_a.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // Sample a_i + RandomFillUniformInModulus::random_fill(rng, &q, ai.as_mut()); + + // a_i * s + scratch_space.as_mut().copy_from_slice(ai.as_ref()); + ntt_op.forward(scratch_space.as_mut()); + mod_op.elwise_mul_mut(scratch_space.as_mut(), s_eval.as_ref()); + ntt_op.backward(scratch_space.as_mut()); + + // b_i = e_i + a_i * s + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); + + // a_i + \beta_i * m + mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(ai.as_mut(), scratch_space.as_ref()); + }); + + // RLWE(m) + let mut a_rlwe_dash_m = { + // polynomials of part A of RLWE'(m) are sampled from seed + let mut a = Mmut::zeros(d_b, ring_size); + a.iter_rows_mut() + .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, &q, ai.as_mut())); + a + }; + + izip!( + a_rlwe_dash_m.iter_rows_mut(), + b_rlwe_dash_m.iter_mut(), + gadget_b.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // ai * s + ntt_op.forward(ai.as_mut()); + mod_op.elwise_mul_mut(ai.as_mut(), s_eval.as_ref()); + ntt_op.backward(ai.as_mut()); + + // beta_i * m + mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i); + + // Sample e_i + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + // e_i + beta_i * m + ai*s + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref()); + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + }); +} + +/// Returns RGSW(m) encrypted with public key +/// +/// Follows the same routine as `secret_key_encrypt_rgsw` but with the +/// difference that each RLWE encryption uses public key instead of secret key. +/// +/// Since public key encryption cannot be seeded `RLWE'_A(m)` is included in the +/// ciphertext. Hence the returned RGSW ciphertext has d_a * 2 + d_b * 2 rows +pub(crate) fn public_key_encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + M: Matrix, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M> + + RandomFill<[u8]> + + RandomElementInModulus, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m: &[M::MatElement], + public_key: &M, + gadget_a: &[Mmut::MatElement], + gadget_b: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut + RowEntity + TryConvertFrom1<[i32], ModOp::M>, + Mmut::MatElement: Copy, +{ + let ring_size = public_key.dimension().1; + let d_a = gadget_a.len(); + let d_b = gadget_b.len(); + assert!(public_key.dimension().0 == 2); + assert!(out_rgsw.dimension() == (d_a * 2 + d_b * 2, ring_size)); + + let mut pk_eval = Mmut::zeros(2, ring_size); + izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| { + to_i.as_mut().copy_from_slice(from_i.as_ref()); + ntt_op.forward(to_i.as_mut()); + }); + let p0 = pk_eval.get_row_slice(0); + let p1 = pk_eval.get_row_slice(1); + + let q = mod_op.modulus(); + + // RGSW(m) = RLWE'(-sm), RLWE(m) + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2); + + // RLWE(-sm) + let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d_a); + izip!( + rlwe_dash_nsm_parta.iter_mut(), + rlwe_dash_nsm_partb.iter_mut(), + gadget_a.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // a = p0*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + }); + + // RLWE(m) + let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d_b); + izip!( + rlwe_dash_m_parta.iter_mut(), + rlwe_dash_m_partb.iter_mut(), + gadget_b.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // sample ephemeral secret u_i + let mut u = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q); + ntt_op.forward(u_eval.as_mut()); + + let mut u_eval_copy = Mmut::R::zeros(ring_size); + u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref()); + + // p0 * u + mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref()); + // p1 * u + mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref()); + ntt_op.backward(u_eval.as_mut()); + ntt_op.backward(u_eval_copy.as_mut()); + + // sample error + RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut()); + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + + // a = p0*u+e0 + mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref()); + // b = p1*u+e1 + mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref()); + + // b = p1*u + e0 + \beta*m + // use u_eval as scratch + mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i); + mod_op.elwise_add_mut(bi.as_mut(), u_eval.as_ref()); + }); +} + +/// Returns key switching key to key switch ciphertext RLWE_{from_s}(m) +/// to RLWE_{to_s}(m). +/// +/// Let key switching decomposer have `d` decompostion count with gadget vector: +/// [1, \beta, ..., \beta^d-1] +/// +/// Key switching key consists of `d` RLWE ciphertexts: +/// RLWE'_{to_s}(-from_s) = [RLWE_{to_s}(\beta^i -from_s)] +/// +/// In RLWE(m) s.t. b = as + e + m where s is the secret key, `a` can be seeded. +/// And we seed all RLWE ciphertexts in key switchin key. +/// +/// - neg_from_s: Negative of secret polynomial to key switch from (i.e. +/// -from_s) +/// - to_s: secret polynomial to key switch to. +/// - gadget_vector: Gadget vector of decomposer used in key switch +/// - p_rng: Seeded pseudo random generate used to generate `a` polynomials of +/// key switching key RLWE ciphertexts +fn seeded_rlwe_ksk_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + + VectorOps + + GetModulus, + NttOp: Ntt, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, +>( + ksk_out: &mut Mmut, + neg_from_s: Mmut::R, + mut to_s: Mmut::R, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut, +{ + let ring_size = neg_from_s.as_ref().len(); + let d = gadget_vector.len(); + assert!(ksk_out.dimension() == (d, ring_size)); + + let q = mod_op.modulus(); + + ntt_op.forward(to_s.as_mut()); + + // RLWE'_{to_s}(-from_s) + let mut part_a = { + let mut a = Mmut::zeros(d, ring_size); + a.iter_rows_mut() + .for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, q, ai.as_mut())); + a + }; + izip!( + part_a.iter_rows_mut(), + ksk_out.iter_rows_mut(), + gadget_vector.iter(), + ) + .for_each(|(ai, bi, beta_i)| { + // si * ai + ntt_op.forward(ai.as_mut()); + mod_op.elwise_mul_mut(ai.as_mut(), to_s.as_ref()); + ntt_op.backward(ai.as_mut()); + + // ei + to_s*ai + RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + + // beta_i * -from_s + // use ai as scratch space + mod_op.elwise_scalar_mul(ai.as_mut(), neg_from_s.as_ref(), beta_i); + + // bi = ei + to_s*ai + beta_i*-from_s + mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref()); + }); +} + +/// Returns auto key to send RLWE(m(X)) -> RLWE(m(X^k)) +/// +/// Auto key is key switchin key that key-switches RLWE_{s(X^k)}(m(X^k)) to +/// RLWE_{s(X)}(m(X^k)). +/// +/// - s: secret polynomial s(X) +/// - auto_k: k used in for autmorphism X -> X^k +/// - gadget_vector: Gadget vector corresponding to decomposer used in key +/// switch +/// - p_rng: pseudo random generator used to generate `a` polynomials of key +/// switching key RLWE ciphertexts +pub(crate) fn seeded_auto_key_gen< + Mmut: MatrixMut + MatrixEntity, + ModOp: ArithmeticOps + + VectorOps + + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>, + PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>, +>( + ksk_out: &mut Mmut, + s: &[S], + auto_k: isize, + gadget_vector: &[Mmut::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + ::R: RowMut, + Mmut::R: TryConvertFrom1<[S], ModOp::M> + RowEntity, + Mmut::MatElement: Copy + Sub, +{ + let ring_size = s.len(); + let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k); + + let q = mod_op.modulus(); + + // s(X) -> -s(X^k) + let s = Mmut::R::try_convert_from(s, q); + let mut neg_s_auto = Mmut::R::zeros(s.as_ref().len()); + izip!(s.as_ref(), auto_map_index.iter(), auto_map_sign.iter()).for_each( + |(el, to_index, sign)| { + // if sign is +ve (true), then negate because we need -s(X) (i.e. do the + // opposite than the usual case) + if *sign { + neg_s_auto.as_mut()[*to_index] = mod_op.neg(el); + } else { + neg_s_auto.as_mut()[*to_index] = *el; + } + }, + ); + + // Ksk from -s(X^k) to s(X) + seeded_rlwe_ksk_gen( + ksk_out, + neg_s_auto, + s, + gadget_vector, + mod_op, + ntt_op, + p_rng, + rng, + ); +} + +/// Returns seeded RLWE(m(X)) +/// +/// RLWE(m(X)) = [a(X), b(X) = a(X)s(X) + e(X) + m(X)] +/// +/// a(X) of RLWE encyrptions using secret key s(X) can be seeded. We use seeded +/// pseudo random generator `p_rng` to sample a(X) and return seeded RLWE +/// ciphertext (i.e. only b(X)) +pub(crate) fn seeded_secret_key_encrypt_rlwe< + Ro: Row + RowMut + RowEntity, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, + PR: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, +>( + m: &[Ro::Element], + b_rlwe_out: &mut Ro, + s: &[S], + mod_op: &ModOp, + ntt_op: &NttOp, + p_rng: &mut PR, + rng: &mut R, +) where + Ro: TryConvertFrom1<[S], ModOp::M> + Debug, +{ + let ring_size = s.len(); + assert!(m.as_ref().len() == ring_size); + assert!(b_rlwe_out.as_ref().len() == ring_size); + + let q = mod_op.modulus(); + + // sample a + let mut a = { + let mut a = Ro::zeros(ring_size); + RandomFillUniformInModulus::random_fill(p_rng, q, a.as_mut()); + a + }; + + // s * a + let mut sa = Ro::try_convert_from(s, q); + ntt_op.forward(sa.as_mut()); + ntt_op.forward(a.as_mut()); + mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); + ntt_op.backward(sa.as_mut()); + + // sample e + RandomFillGaussianInModulus::random_fill(rng, q, b_rlwe_out.as_mut()); + mod_op.elwise_add_mut(b_rlwe_out.as_mut(), m.as_ref()); + mod_op.elwise_add_mut(b_rlwe_out.as_mut(), sa.as_ref()); +} + +/// Returns RLWE(m(X)) encrypted using public key. +/// +/// Unlike secret key encryption, public key encryption cannot be seeded +pub(crate) fn public_key_encrypt_rlwe< + M: Matrix, + Mmut: MatrixMut, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, + R: RandomFillGaussianInModulus<[M::MatElement], ModOp::M> + + RandomFillUniformInModulus<[M::MatElement], ModOp::M> + + RandomFill<[u8]> + + RandomElementInModulus, +>( + rlwe_out: &mut Mmut, + pk: &M, + m: &[M::MatElement], + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, + M::MatElement: Copy, + S: Zero + Signed + Copy, +{ + let ring_size = m.len(); + assert!(rlwe_out.dimension() == (2, ring_size)); + + let q = mod_op.modulus(); + + let mut u = vec![S::zero(); ring_size]; + fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); + let mut u = Mmut::R::try_convert_from(&u, q); + ntt_op.forward(u.as_mut()); + + let mut ua = Mmut::R::zeros(ring_size); + ua.as_mut().copy_from_slice(pk.get_row_slice(0)); + let mut ub = Mmut::R::zeros(ring_size); + ub.as_mut().copy_from_slice(pk.get_row_slice(1)); + + // a*u + ntt_op.forward(ua.as_mut()); + mod_op.elwise_mul_mut(ua.as_mut(), u.as_ref()); + ntt_op.backward(ua.as_mut()); + + // b*u + ntt_op.forward(ub.as_mut()); + mod_op.elwise_mul_mut(ub.as_mut(), u.as_ref()); + ntt_op.backward(ub.as_mut()); + + // sample error + rlwe_out.iter_rows_mut().for_each(|ri| { + RandomFillGaussianInModulus::random_fill(rng, &q, ri.as_mut()); + }); + + // a*u + e0 + mod_op.elwise_add_mut(rlwe_out.get_row_mut(0), ua.as_ref()); + // b*u + e1 + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), ub.as_ref()); + + // b*u + e1 + m + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m); +} + +/// Returns RLWE public key generated using RLWE secret key +pub(crate) fn rlwe_public_key< + Ro: RowMut + RowEntity, + S, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + PRng: RandomFillUniformInModulus<[Ro::Element], ModOp::M>, + Rng: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>, +>( + part_b_out: &mut Ro, + s: &[S], + ntt_op: &NttOp, + mod_op: &ModOp, + p_rng: &mut PRng, + rng: &mut Rng, +) where + Ro: TryConvertFrom1<[S], ModOp::M>, +{ + let ring_size = s.len(); + assert!(part_b_out.as_ref().len() == ring_size); + + let q = mod_op.modulus(); + + // sample a + let mut a = { + let mut tmp = Ro::zeros(ring_size); + RandomFillUniformInModulus::random_fill(p_rng, &q, tmp.as_mut()); + tmp + }; + ntt_op.forward(a.as_mut()); + + // s*a + let mut sa = Ro::try_convert_from(s, &q); + ntt_op.forward(sa.as_mut()); + mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref()); + ntt_op.backward(sa.as_mut()); + + // s*a + e + RandomFillGaussianInModulus::random_fill(rng, &q, part_b_out.as_mut()); + mod_op.elwise_add_mut(part_b_out.as_mut(), sa.as_ref()); +} + +/// Decrypts ciphertext RLWE(m) and returns noisy m +/// +/// We assume RLWE(m) = [a, b] is a degree 1 ciphertext s.t. b - sa = e + m +pub(crate) fn decrypt_rlwe< + R: RowMut, + M: Matrix, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, +>( + rlwe_ct: &M, + s: &[S], + m_out: &mut R, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + R: TryConvertFrom1<[S], ModOp::M>, + R::Element: Copy, +{ + let ring_size = s.len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(m_out.as_ref().len() == ring_size); + + // transform a to evluation form + m_out.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(m_out.as_mut()); + + // -s*a + let mut s = R::try_convert_from(&s, mod_op.modulus()); + ntt_op.forward(s.as_mut()); + mod_op.elwise_mul_mut(m_out.as_mut(), s.as_ref()); + mod_op.elwise_neg_mut(m_out.as_mut()); + ntt_op.backward(m_out.as_mut()); + + // m+e = b - s*a + mod_op.elwise_add_mut(m_out.as_mut(), rlwe_ct.get_row_slice(1)); +} + +// Measures maximum noise in degree 1 RLWE ciphertext against message `want_m` +fn measure_max_noise< + Mmut: MatrixMut + Matrix, + ModOp: VectorOps + GetModulus, + NttOp: Ntt, + S, +>( + rlwe_ct: &Mmut, + want_m: &Mmut::R, + ntt_op: &NttOp, + mod_op: &ModOp, + s: &[S], +) -> f64 +where + ::R: RowMut, + Mmut::R: RowEntity + TryConvertFrom1<[S], ModOp::M>, + Mmut::MatElement: PrimInt + ToPrimitive + Debug, +{ + let ring_size = s.len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(want_m.as_ref().len() == ring_size); + + // -(s * a) + let q = mod_op.modulus(); + let mut s = Mmut::R::try_convert_from(s, &q); + ntt_op.forward(s.as_mut()); + let mut a = Mmut::R::zeros(ring_size); + a.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(a.as_mut()); + mod_op.elwise_mul_mut(s.as_mut(), a.as_ref()); + mod_op.elwise_neg_mut(s.as_mut()); + ntt_op.backward(s.as_mut()); + + // m+e = b - s*a + let mut m_plus_e = s; + mod_op.elwise_add_mut(m_plus_e.as_mut(), rlwe_ct.get_row_slice(1)); + + // difference + mod_op.elwise_sub_mut(m_plus_e.as_mut(), want_m.as_ref()); + + let mut max_diff_bits = f64::MIN; + m_plus_e.as_ref().iter().for_each(|v| { + let bits = (q.map_element_to_i64(v).to_f64().unwrap().abs()).log2(); + + if max_diff_bits < bits { + max_diff_bits = bits; + } + }); + + return max_diff_bits; +} diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs new file mode 100644 index 0000000..a50608a --- /dev/null +++ b/src/rgsw/mod.rs @@ -0,0 +1,982 @@ +mod keygen; +mod runtime; + +pub(crate) use keygen::*; +pub(crate) use runtime::*; + +#[cfg(test)] +pub(crate) mod tests { + use std::{fmt::Debug, marker::PhantomData, vec}; + + use itertools::{izip, Itertools}; + use rand::{thread_rng, Rng}; + + use crate::{ + backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, + decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, + ntt::{Ntt, NttBackendU64, NttInit}, + random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, + rgsw::{ + rlwe_auto_scratch_rows, rlwe_auto_shoup, rlwe_by_rgsw_shoup, rlwe_x_rgsw_scratch_rows, + RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef, RuntimeScratchMutRef, + }, + utils::{ + fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul, + tests::Stats, ToShoup, TryConvertFrom1, WithLocal, + }, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, + }; + + use super::{ + keygen::{ + decrypt_rlwe, generate_auto_map, rlwe_public_key, secret_key_encrypt_rgsw, + seeded_auto_key_gen, seeded_secret_key_encrypt_rlwe, + }, + rgsw_x_rgsw_scratch_rows, + runtime::{rgsw_by_rgsw_inplace, rlwe_auto, rlwe_by_rgsw}, + RgswCiphertextMutRef, + }; + + struct SeededAutoKey + where + M: Matrix, + { + data: M, + seed: S, + modulus: Mod, + } + + impl> SeededAutoKey { + fn empty( + ring_size: usize, + auto_decomposer: &D, + seed: S, + modulus: Mod, + ) -> Self { + SeededAutoKey { + data: M::zeros(auto_decomposer.decomposition_count().0, ring_size), + seed, + modulus, + } + } + } + + struct AutoKeyEvaluationDomain { + data: M, + _phantom: PhantomData<(R, N)>, + } + + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R: RandomFillUniformInModulus<[M::MatElement], Mod> + NewWithSeed, + N: NttInit + Ntt, + > From<&SeededAutoKey> for AutoKeyEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + + R::Seed: Clone, + { + fn from(value: &SeededAutoKey) -> Self { + let (d, ring_size) = value.data.dimension(); + let mut data = M::zeros(2 * d, ring_size); + + // sample RLWE'_A(-s(X^k)) + let mut p_rng = R::new_with_seed(value.seed.clone()); + data.iter_rows_mut().take(d).for_each(|r| { + RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, r.as_mut()); + }); + + // copy over RLWE'_B(-s(X^k)) + izip!(data.iter_rows_mut().skip(d), value.data.iter_rows()).for_each( + |(to_r, from_r)| { + to_r.as_mut().copy_from_slice(from_r.as_ref()); + }, + ); + + // send RLWE'(-s(X^k)) polynomials to evaluation domain + let ntt_op = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + + AutoKeyEvaluationDomain { + data, + _phantom: PhantomData, + } + } + } + + struct RgswCiphertext { + /// Rgsw ciphertext polynomials + data: M, + modulus: Mod, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, + } + + impl> RgswCiphertext { + pub(crate) fn empty( + ring_size: usize, + decomposer: &D, + modulus: Mod, + ) -> RgswCiphertext { + RgswCiphertext { + data: M::zeros( + decomposer.a().decomposition_count().0 * 2 + + decomposer.b().decomposition_count().0 * 2, + ring_size, + ), + d_a: decomposer.a().decomposition_count().0, + d_b: decomposer.b().decomposition_count().0, + modulus, + } + } + } + + pub struct SeededRgswCiphertext + where + M: Matrix, + { + pub(crate) data: M, + seed: S, + modulus: Mod, + /// Decomposition for RLWE part A + d_a: usize, + /// Decomposition for RLWE part B + d_b: usize, + } + + impl SeededRgswCiphertext { + pub(crate) fn empty( + ring_size: usize, + decomposer: &D, + seed: S, + modulus: Mod, + ) -> SeededRgswCiphertext { + SeededRgswCiphertext { + data: M::zeros( + decomposer.a().decomposition_count().0 * 2 + + decomposer.b().decomposition_count().0, + ring_size, + ), + seed, + modulus, + d_a: decomposer.a().decomposition_count().0, + d_b: decomposer.b().decomposition_count().0, + } + } + } + + impl Debug for SeededRgswCiphertext + where + M::MatElement: Debug, + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SeededRgswCiphertext") + .field("data", &self.data) + .field("seed", &self.seed) + .field("modulus", &self.modulus) + .finish() + } + } + + pub struct RgswCiphertextEvaluationDomain { + pub(crate) data: M, + modulus: Mod, + _phantom: PhantomData<(R, N)>, + } + + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, + N: NttInit + Ntt + Debug, + > From<&SeededRgswCiphertext> + for RgswCiphertextEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + R::Seed: Clone, + M: Debug, + { + fn from(value: &SeededRgswCiphertext) -> Self { + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); + + // copy RLWE'(-sm) + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // sample A polynomials of RLWE'(m) - RLWE'A(m) + let mut p_rng = R::new_with_seed(value.seed.clone()); + izip!(data.iter_rows_mut().skip(value.d_a * 2).take(value.d_b * 1)) + .for_each(|ri| p_rng.random_fill(&value.modulus, ri.as_mut())); + + // RLWE'_B(m) + izip!( + data.iter_rows_mut().skip(value.d_a * 2 + value.d_b), + value.data.iter_rows().skip(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // Send polynomials to evaluation domain + let ring_size = data.dimension().1; + let nttop = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + Self { + data: data, + modulus: value.modulus.clone(), + _phantom: PhantomData, + } + } + } + + impl< + M: MatrixMut + MatrixEntity, + Mod: Modulus + Clone, + R, + N: NttInit + Ntt, + > From<&RgswCiphertext> for RgswCiphertextEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + M: Debug, + { + fn from(value: &RgswCiphertext) -> Self { + let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1); + + // copy RLWE'(-sm) + izip!( + data.iter_rows_mut().take(value.d_a * 2), + value.data.iter_rows().take(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // copy RLWE'(m) + izip!( + data.iter_rows_mut().skip(value.d_a * 2), + value.data.iter_rows().skip(value.d_a * 2) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // Send polynomials to evaluation domain + let ring_size = data.dimension().1; + let nttop = N::new(&value.modulus, ring_size); + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + Self { + data: data, + modulus: value.modulus.clone(), + _phantom: PhantomData, + } + } + } + + impl Debug for RgswCiphertextEvaluationDomain { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RgswCiphertextEvaluationDomain") + .field("data", &self.data) + .field("modulus", &self.modulus) + .field("_phantom", &self._phantom) + .finish() + } + } + + struct SeededRlweCiphertext { + data: R, + seed: S, + modulus: Mod, + } + + impl SeededRlweCiphertext { + fn empty(ring_size: usize, seed: S, modulus: Mod) -> Self { + SeededRlweCiphertext { + data: R::zeros(ring_size), + seed, + modulus, + } + } + } + + pub struct RlweCiphertext { + data: M, + _phatom: PhantomData, + } + + impl< + R: Row, + M: MatrixEntity + MatrixMut, + Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>, + Mod: Modulus, + > From<&SeededRlweCiphertext> for RlweCiphertext + where + Rng::Seed: Clone, + ::R: RowMut, + R::Element: Copy, + { + fn from(value: &SeededRlweCiphertext) -> Self { + let mut data = M::zeros(2, value.data.as_ref().len()); + + // sample a + let mut p_rng = Rng::new_with_seed(value.seed.clone()); + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &value.modulus, + data.get_row_mut(0), + ); + + data.get_row_mut(1).copy_from_slice(value.data.as_ref()); + + RlweCiphertext { + data, + _phatom: PhantomData, + } + } + } + + struct SeededRlwePublicKey { + data: Ro, + seed: S, + modulus: Ro::Element, + } + + impl SeededRlwePublicKey { + pub(crate) fn empty(ring_size: usize, seed: S, modulus: Ro::Element) -> Self { + Self { + data: Ro::zeros(ring_size), + seed, + modulus, + } + } + } + + struct RlwePublicKey { + data: M, + _phantom: PhantomData, + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>, + > From<&SeededRlwePublicKey> for RlwePublicKey + where + ::R: RowMut, + M::MatElement: Copy, + Rng::Seed: Copy, + { + fn from(value: &SeededRlwePublicKey) -> Self { + let mut data = M::zeros(2, value.data.as_ref().len()); + + // sample a + let mut p_rng = Rng::new_with_seed(value.seed); + RandomFillUniformInModulus::random_fill( + &mut p_rng, + &value.modulus, + data.get_row_mut(0), + ); + + // copy over b + data.get_row_mut(1).copy_from_slice(value.data.as_ref()); + + Self { + data, + _phantom: PhantomData, + } + } + } + + #[derive(Clone)] + struct RlweSecret { + pub(crate) values: Vec, + } + + impl RlweSecret { + pub fn random(hw: usize, n: usize) -> RlweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + + RlweSecret { values: out } + }) + } + + fn values(&self) -> &[i32] { + &self.values + } + } + + fn random_seed() -> [u8; 32] { + let mut rng = DefaultSecureRng::new(); + let mut seed = [0u8; 32]; + rng.fill_bytes(&mut seed); + seed + } + + /// Encrypts m as RGSW ciphertext RGSW(m) using supplied secret key. Returns + /// seeded RGSW ciphertext in coefficient domain + fn sk_encrypt_rgsw + Clone>( + m: &[u64], + s: &[i32], + decomposer: &(DefaultDecomposer, DefaultDecomposer), + mod_op: &ModularOpsU64, + ntt_op: &NttBackendU64, + ) -> SeededRgswCiphertext>, [u8; 32], T> { + let ring_size = s.len(); + assert!(m.len() == s.len()); + + let mut rng = DefaultSecureRng::new(); + + let q = mod_op.modulus(); + let rgsw_seed = random_seed(); + let mut seeded_rgsw_ct = SeededRgswCiphertext::>, [u8; 32], T>::empty( + ring_size as usize, + decomposer, + rgsw_seed, + q.clone(), + ); + let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed); + secret_key_encrypt_rgsw( + &mut seeded_rgsw_ct.data, + m, + &decomposer.a().gadget_vector(), + &decomposer.b().gadget_vector(), + s, + mod_op, + ntt_op, + &mut p_rng, + &mut rng, + ); + seeded_rgsw_ct + } + + #[test] + fn rlwe_encrypt_decryption() { + let logq = 50; + let logp = 2; + let ring_size = 1 << 4; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let p = 1u64 << logp; + + let mut rng = DefaultSecureRng::new(); + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + // sample m0 + let mut m0 = vec![0u64; ring_size as usize]; + RandomFillUniformInModulus::<[u64], u64>::random_fill( + &mut rng, + &(1u64 << logp), + m0.as_mut_slice(), + ); + + let ntt_op = NttBackendU64::new(&q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + + // encrypt m0 + let encoded_m = m0 + .iter() + .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + .collect_vec(); + let seed = random_seed(); + let mut rlwe_in_ct = + SeededRlweCiphertext::, _, _>::empty(ring_size as usize, seed, q); + let mut p_rng = DefaultSecureRng::new_seeded(seed); + seeded_secret_key_encrypt_rlwe( + &encoded_m, + &mut rlwe_in_ct.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + let rlwe_in_ct = RlweCiphertext::>, DefaultSecureRng>::from(&rlwe_in_ct); + + let mut encoded_m_back = vec![0u64; ring_size as usize]; + decrypt_rlwe( + &rlwe_in_ct.data, + s.values(), + &mut encoded_m_back, + &ntt_op, + &mod_op, + ); + let m_back = encoded_m_back + .iter() + .map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p) + .collect_vec(); + assert_eq!(m0, m_back); + } + + #[test] + fn rlwe_by_rgsw_works() { + let logq = 50; + let logp = 2; + let ring_size = 1 << 4; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let p: u64 = 1u64 << logp; + + let mut rng = DefaultSecureRng::new_seeded([0u8; 32]); + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut m0 = vec![0u64; ring_size as usize]; + RandomFillUniformInModulus::<[u64], _>::random_fill( + &mut rng, + &(1u64 << logp), + m0.as_mut_slice(), + ); + let mut m1 = vec![0u64; ring_size as usize]; + m1[thread_rng().gen_range(0..ring_size) as usize] = 1; + + let ntt_op = NttBackendU64::new(&q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + let d_rgsw = 10; + let logb = 5; + let decomposer = ( + DefaultDecomposer::new(q, logb, d_rgsw), + DefaultDecomposer::new(q, logb, d_rgsw), + ); + + // create public key + let pk_seed = random_seed(); + let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed); + let mut seeded_pk = + SeededRlwePublicKey::, _>::empty(ring_size as usize, pk_seed, q); + rlwe_public_key( + &mut seeded_pk.data, + s.values(), + &ntt_op, + &mod_op, + &mut pk_prng, + &mut rng, + ); + // let pk = RlwePublicKey::>, DefaultSecureRng>::from(&seeded_pk); + + // Encrypt m1 as RGSW(m1) + let rgsw_ct = { + // Encryption m1 as RGSW(m1) using secret key + let seeded_rgsw_ct = sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + }; + + // Encrypt m0 as RLWE(m0) + let mut rlwe_in_ct = { + let encoded_m = m0 + .iter() + .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + .collect_vec(); + + let seed = random_seed(); + let mut p_rng = DefaultSecureRng::new_seeded(seed); + let mut seeded_rlwe = SeededRlweCiphertext::empty(ring_size as usize, seed, q); + seeded_secret_key_encrypt_rlwe( + &encoded_m, + &mut seeded_rlwe.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe) + }; + + // RLWE(m0m1) = RLWE(m0) x RGSW(m1) + let mut scratch_space = + vec![vec![0u64; ring_size as usize]; rlwe_x_rgsw_scratch_rows(&decomposer)]; + + // rlwe x rgsw with with soup repr + let rlwe_in_ct_shoup = { + let mut rlwe_in_ct_shoup = rlwe_in_ct.data.clone(); + + let rgsw_ct_shoup = ToShoup::to_shoup(&rgsw_ct.data, q); + + rlwe_by_rgsw_shoup( + &mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()), + &RgswCiphertextRef::new( + rgsw_ct.data.as_ref(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, + ), + &RgswCiphertextRef::new( + rgsw_ct_shoup.as_ref(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, + ), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), + &decomposer, + &ntt_op, + &mod_op, + false, + ); + + rlwe_in_ct_shoup + }; + + // rlwe x rgsw normal + { + rlwe_by_rgsw( + &mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()), + &RgswCiphertextRef::new( + rgsw_ct.data.as_ref(), + decomposer.a().decomposition_count().0, + decomposer.b().decomposition_count().0, + ), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), + &decomposer, + &ntt_op, + &mod_op, + false, + ); + } + + // output from both functions must be equal + assert_eq!(rlwe_in_ct.data, rlwe_in_ct_shoup); + + // Decrypt RLWE(m0m1) + let mut encoded_m0m1_back = vec![0u64; ring_size as usize]; + decrypt_rlwe( + &rlwe_in_ct_shoup, + s.values(), + &mut encoded_m0m1_back, + &ntt_op, + &mod_op, + ); + let m0m1_back = encoded_m0m1_back + .iter() + .map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p) + .collect_vec(); + + let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p; + let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p); + + // { + // // measure noise + // let encoded_m_ideal = m0m1 + // .iter() + // .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + // .collect_vec(); + + // let noise = measure_noise(&rlwe_in_ct, &encoded_m_ideal, &ntt_op, + // &mod_op, s.values()); println!("Noise RLWE(m0m1)(= + // RLWE(m0)xRGSW(m1)) : {noise}"); } + + assert!( + m0m1 == m0m1_back, + "Expected {:?} \n Got {:?}", + m0m1, + m0m1_back + ); + } + + #[test] + fn rlwe_auto_works() { + let logq = 55; + let ring_size = 1 << 11; + let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(); + let logp = 3; + let p = 1u64 << logp; + let d_rgsw = 5; + let logb = 11; + + let mut rng = DefaultSecureRng::new(); + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut m = vec![0u64; ring_size as usize]; + RandomFillUniformInModulus::random_fill(&mut rng, &p, m.as_mut_slice()); + let encoded_m = m + .iter() + .map(|v| (((*v as f64 * q as f64) / (p as f64)).round() as u64)) + .collect_vec(); + + let ntt_op = NttBackendU64::new(&q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + + // RLWE_{s}(m) + let seed_rlwe = random_seed(); + let mut seeded_rlwe_m = SeededRlweCiphertext::empty(ring_size as usize, seed_rlwe, q); + let mut p_rng = DefaultSecureRng::new_seeded(seed_rlwe); + seeded_secret_key_encrypt_rlwe( + &encoded_m, + &mut seeded_rlwe_m.data, + s.values(), + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + let mut rlwe_m = RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_m); + + let auto_k = -125; + + // Generate auto key to key switch from s^k to s + let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); + let seed_auto = random_seed(); + let mut seeded_auto_key = + SeededAutoKey::empty(ring_size as usize, &decomposer, seed_auto, q); + let mut p_rng = DefaultSecureRng::new_seeded(seed_auto); + let gadget_vector = decomposer.gadget_vector(); + seeded_auto_key_gen( + &mut seeded_auto_key.data, + s.values(), + auto_k, + &gadget_vector, + &mod_op, + &ntt_op, + &mut p_rng, + &mut rng, + ); + let auto_key = + AutoKeyEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from( + &seeded_auto_key, + ); + + // Send RLWE_{s}(m) -> RLWE_{s}(m^k) + let mut scratch_space = + vec![vec![0; ring_size as usize]; rlwe_auto_scratch_rows(&decomposer)]; + let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k); + + // galois auto with auto key in shoup repr + let rlwe_m_shoup = { + let auto_key_shoup = ToShoup::to_shoup(&auto_key.data, q); + let mut rlwe_m_shoup = rlwe_m.data.clone(); + rlwe_auto_shoup( + &mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup), + &RlweKskRef::new(&auto_key.data, decomposer.decomposition_count().0), + &RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count().0), + &mut RuntimeScratchMutRef::new(&mut scratch_space), + &auto_map_index, + &auto_map_sign, + &mod_op, + &ntt_op, + &decomposer, + false, + ); + rlwe_m_shoup + }; + + // normal galois auto + { + rlwe_auto( + &mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()), + &RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count().0), + &mut RuntimeScratchMutRef::new(scratch_space.as_mut()), + &auto_map_index, + &auto_map_sign, + &mod_op, + &ntt_op, + &decomposer, + false, + ); + } + + // rlwe out from both functions must be same + assert_eq!(rlwe_m.data, rlwe_m_shoup); + + let rlwe_m_k = rlwe_m; + + // Decrypt RLWE_{s}(m^k) and check + let mut encoded_m_k_back = vec![0u64; ring_size as usize]; + decrypt_rlwe( + &rlwe_m_k.data, + s.values(), + &mut encoded_m_k_back, + &ntt_op, + &mod_op, + ); + let m_k_back = encoded_m_k_back + .iter() + .map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p) + .collect_vec(); + + let mut m_k = vec![0u64; ring_size as usize]; + // Send \delta m -> \delta m^k + izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each( + |(v, to_index, sign)| { + if !*sign { + m_k[*to_index] = (p - *v) % p; + } else { + m_k[*to_index] = *v; + } + }, + ); + + // { + // let encoded_m_k = m_k + // .iter() + // .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64) + // .collect_vec(); + + // let noise = measure_noise(&rlwe_m_k, &encoded_m_k, &ntt_op, &mod_op, + // s.values()); println!("Ksk noise: {noise}"); + // } + + assert_eq!(m_k_back, m_k); + } + + /// Collect noise stats of RGSW ciphertext + /// + /// - rgsw_ct: RGSW ciphertext must be in coefficient domain + fn rgsw_noise_stats + Clone>( + rgsw_ct: &[Vec], + m: &[u64], + s: &[i32], + decomposer: &(DefaultDecomposer, DefaultDecomposer), + q: &T, + ) -> Stats { + let gadget_vector_a = decomposer.a().gadget_vector(); + let gadget_vector_b = decomposer.b().gadget_vector(); + let d_a = gadget_vector_a.len(); + let d_b = gadget_vector_b.len(); + let ring_size = s.len(); + assert!(Matrix::dimension(&rgsw_ct) == (d_a * 2 + d_b * 2, ring_size)); + assert!(m.len() == ring_size); + + let mod_op = ModularOpsU64::new(q.clone()); + let ntt_op = NttBackendU64::new(q, ring_size); + + let mul_mod = + |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q.q().unwrap() as u128) as u64; + let s_poly = Vec::::try_convert_from(s, q); + let mut neg_s = s_poly.clone(); + mod_op.elwise_neg_mut(neg_s.as_mut()); + let neg_sm0m1 = negacyclic_mul(&neg_s, &m, mul_mod, q.q().unwrap()); + + let mut stats = Stats::new(); + + // RLWE(\beta^j -s * m) + for j in 0..d_a { + let want_m = { + // RLWE(\beta^j -s * m) + let mut beta_neg_sm0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_neg_sm0m1.as_mut(), &neg_sm0m1, &gadget_vector_a[j]); + beta_neg_sm0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a + j)); + + let mut got_m = vec![0; ring_size]; + decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op); + + let mut diff = want_m; + mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref()); + stats.add_many_samples(&Vec::::try_convert_from(&diff, q)); + } + + // RLWE(\beta^j m) + for j in 0..d_b { + let want_m = { + // RLWE(\beta^j m) + let mut beta_m0m1 = vec![0u64; ring_size as usize]; + mod_op.elwise_scalar_mul(beta_m0m1.as_mut(), &m, &gadget_vector_b[j]); + beta_m0m1 + }; + + let mut rlwe = vec![vec![0u64; ring_size as usize]; 2]; + rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + j)); + rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + d_b + j)); + + let mut got_m = vec![0; ring_size]; + decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op); + + let mut diff = want_m; + mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref()); + stats.add_many_samples(&Vec::::try_convert_from(&diff, q)); + } + + stats + } + + #[test] + fn print_noise_stats_rgsw_x_rgsw() { + let logq = 60; + let logp = 2; + let ring_size = 1 << 11; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let d_rgsw = 12; + let logb = 5; + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let ntt_op = NttBackendU64::new(&q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + let decomposer = ( + DefaultDecomposer::new(q, logb, d_rgsw), + DefaultDecomposer::new(q, logb, d_rgsw), + ); + + let d_a = decomposer.a().decomposition_count().0; + let d_b = decomposer.b().decomposition_count().0; + + let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64; + + let mut carry_m = vec![0u64; ring_size as usize]; + carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1 << logp; + + // RGSW(carry_m) + let mut rgsw_carrym = { + let seeded_rgsw = sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op); + let mut rgsw_eval = + RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &seeded_rgsw, + ); + rgsw_eval + .data + .iter_mut() + .for_each(|ri| ntt_op.backward(ri.as_mut())); + rgsw_eval.data + }; + + let mut scratch_matrix = vec![ + vec![0u64; ring_size as usize]; + rgsw_x_rgsw_scratch_rows(&decomposer, &decomposer) + ]; + + rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q); + + for i in 0..8 { + let mut m = vec![0u64; ring_size as usize]; + m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; + let rgsw_m = + RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from( + &sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op), + ); + + rgsw_by_rgsw_inplace( + &mut RgswCiphertextMutRef::new(rgsw_carrym.as_mut(), d_a, d_b), + &RgswCiphertextRef::new(rgsw_m.data.as_ref(), d_a, d_b), + &decomposer, + &decomposer, + &mut RuntimeScratchMutRef::new(scratch_matrix.as_mut()), + &ntt_op, + &mod_op, + ); + + // measure noise + carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q); + let stats = rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q); + println!( + "Log2 of noise std after {i} RGSW x RGSW: {}", + stats.std_dev().abs().log2() + ); + } + } +} diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs new file mode 100644 index 0000000..bd56c5d --- /dev/null +++ b/src/rgsw/runtime.rs @@ -0,0 +1,1063 @@ +use itertools::izip; +use num_traits::Zero; + +use crate::{ + backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps}, + decomposer::{Decomposer, RlweDecomposer}, + ntt::Ntt, + parameters::{DecompositionCount, DoubleDecomposerParams, SingleDecomposerParams}, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, +}; + +/// Degree 1 RLWE ciphertext. +/// +/// RLWE(m) = [a, b] s.t. m+e = b - as +pub(crate) trait RlweCiphertext { + type R: RowMut; + /// Returns polynomial `a` of RLWE ciphertext as slice of elements + fn part_a(&self) -> &[::Element]; + /// Returns polynomial `a` of RLWE ciphertext as mutable slice of elements + fn part_a_mut(&mut self) -> &mut [::Element]; + /// Returns polynomial `b` of RLWE ciphertext as slice of elements + fn part_b(&self) -> &[::Element]; + /// Returns polynomial `b` of RLWE ciphertext as mut slice of elements + fn part_b_mut(&mut self) -> &mut [::Element]; + /// Returns ring size of polynomials + fn ring_size(&self) -> usize; +} + +/// RGSW ciphertext +/// +/// RGSW is a collection of RLWE' ciphertext which are collection degree 1 of +/// RLWE ciphertexts +/// +/// RGSW = [RLWE'(-sm) || RLWE'(m)] +/// +/// As usual we refer to decomposition count for RLWE_A in RLWE x RGSW +/// multiplicaiton as `d_a` and decomposition count for RLWE_B in RLWE x RGDW +/// multiplication as `d_b`. +pub(crate) trait RgswCiphertext { + type R: Row; + + /// Splits RGSW ciphertext and returns references: + /// (RLWE'_A(-sm), RLWE'_B(-sm)), (RLWE'_A(m), RLWE'_B(m)) + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])); +} + +pub(crate) trait RgswCiphertextMut: RgswCiphertext { + /// Splits RGSW ciphertext and returns mutable references: + /// (RLWE'_A(-sm), RLWE'_B(-sm)), (RLWE'_A(m), RLWE'_B(m)) + fn split_mut( + &mut self, + ) -> ( + (&mut [Self::R], &mut [Self::R]), + (&mut [Self::R], &mut [Self::R]), + ); +} + +/// RLWE Key switching Key +/// +/// Key switching key from s' -> s consists of multiple RLWE cipheretxts. +/// For gadget vector: [1, beta, ..., beta^{d-1}] +/// RLWE'_{s}(-s'm) = [RWLE_{s}(-s'm), ..., RLWE_{s}(beta^{d-1} -s'm)] +pub(crate) trait RlweKsk { + type R: Row; + /// Returns reference to RLWE'_A(-s'm) polynomials + fn ksk_part_a(&self) -> &[Self::R]; + /// Returns reference to RLWE'_B(-s'm) polynomials + fn ksk_part_b(&self) -> &[Self::R]; +} + +/// Scratch matrix used in several rlwe/rgsw runtime operations +pub(crate) trait RuntimeScratchMatrix { + type R: RowMut; + type Rgsw: RgswCiphertext; + + /// Returns scratch matrix for RLWE automorphism (not trivial case) + /// + /// RLWE auto requires scratch matric to store decomposed polynomials + 1 + /// rlwe ciphertext temporarily. + /// + /// For example, if Auto decomposer has decompostion count `d` then the + /// scratch matrix must have dimension (d + 2, N) where N is the ring size. + fn scratch_for_rlwe_auto_and_zero_rlwe_space( + &mut self, + decompostion_count: usize, + ) -> (&mut [Self::R], &mut [Self::R]); + + /// Returns scratch matrix for RLWE automorphism (trivial case) + /// + /// We refer to cases where RLWE(m) = [0, b] s.t. m = b as trivial cases. In + /// such a case a single row of length N, N being the ring dimension, is + /// required as scratch buffer to store automorphism of polynomial `b` + /// temporarily. + fn scratch_for_rlwe_auto_trivial_case(&mut self) -> &mut Self::R; + + /// Returns scratch matrix + zeroed RLWE ciphertext space for + /// RLWE x RGSW + /// + /// RLWE x RGSW product requires scratch space to store decomposed + /// polynomials for both cases: (1) SignedDecompose(RLWE_A) x RLWE'(-sm) and + /// (2) SignedDecompose(RLWE_B) x RLWE'(m). Hence, scratch space returned to + /// store decomposed polynomials must have MAX(d_a, d_b) rows. + /// + /// Additional scratch space is required to store 1 RLWE ciphertext + /// temporarily. The space must be zeroed. + fn scratch_for_rlwe_x_rgsw_and_zero_rlwe_space( + &mut self, + decomposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]); + + /// Returns scracth matrix + zeroed RGSW ciphertext space for RGSW0 x RGSW1 + /// + /// RGSW0 x RGSW1 requires `d_{0,a} + d_{0,b}` RLWE x RGSW1 products where + /// d_{0, a/b} are decomposition counts corresponding to decmposer used for + /// RGSW0. Hence, scratch space required to store decomposed polynomial for + /// RLWE x RGSW1 product should have MAX(d_{1, a}, d_{1, b}) rows. + /// + /// Additional scravth space is required to store RGSW0 ciphertext + /// temporarily. The space must be zeroed. + fn scratch_for_rgsw_x_rgsw_and_zero_rgsw0_space( + &mut self, + d0: &D, + d1: &D, + ) -> (&mut [Self::R], &mut [Self::R]); +} + +pub(crate) struct RlweCiphertextMutRef<'a, R> { + data: &'a mut [R], +} + +impl<'a, R> RlweCiphertextMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R]) -> Self { + Self { data } + } +} + +impl<'a, R: RowMut> RlweCiphertext for RlweCiphertextMutRef<'a, R> { + type R = R; + fn part_a(&self) -> &[::Element] { + self.data[0].as_ref() + } + fn part_a_mut(&mut self) -> &mut [::Element] { + self.data[0].as_mut() + } + fn part_b(&self) -> &[::Element] { + self.data[1].as_ref() + } + fn part_b_mut(&mut self) -> &mut [::Element] { + self.data[1].as_mut() + } + fn ring_size(&self) -> usize { + self.data[0].as_ref().len() + } +} + +pub(crate) struct RgswCiphertextRef<'a, R> { + data: &'a [R], + d_a: usize, + d_b: usize, +} + +impl<'a, R> RgswCiphertextRef<'a, R> { + pub(crate) fn new(data: &'a [R], d_a: usize, d_b: usize) -> Self { + RgswCiphertextRef { data, d_a, d_b } + } +} + +impl<'a, R> RgswCiphertext for RgswCiphertextRef<'a, R> +where + R: Row, +{ + type R = R; + + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at(self.d_a * 2); + ( + rlwe_dash_nsm.split_at(self.d_a), + rlwe_dash_m.split_at(self.d_b), + ) + } +} + +pub(crate) struct RgswCiphertextMutRef<'a, R> { + data: &'a mut [R], + d_a: usize, + d_b: usize, +} + +impl<'a, R> RgswCiphertextMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R], d_a: usize, d_b: usize) -> Self { + RgswCiphertextMutRef { data, d_a, d_b } + } +} + +impl<'a, R: RowMut> AsMut<[R]> for RgswCiphertextMutRef<'a, R> { + fn as_mut(&mut self) -> &mut [R] { + &mut self.data + } +} + +impl<'a, R> RgswCiphertext for RgswCiphertextMutRef<'a, R> +where + R: Row, +{ + type R = R; + + fn split(&self) -> ((&[Self::R], &[Self::R]), (&[Self::R], &[Self::R])) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at(self.d_a * 2); + ( + rlwe_dash_nsm.split_at(self.d_a), + rlwe_dash_m.split_at(self.d_b), + ) + } +} + +impl<'a, R> RgswCiphertextMut for RgswCiphertextMutRef<'a, R> +where + R: RowMut, +{ + fn split_mut( + &mut self, + ) -> ( + (&mut [Self::R], &mut [Self::R]), + (&mut [Self::R], &mut [Self::R]), + ) { + let (rlwe_dash_nsm, rlwe_dash_m) = self.data.split_at_mut(self.d_a * 2); + ( + rlwe_dash_nsm.split_at_mut(self.d_a), + rlwe_dash_m.split_at_mut(self.d_b), + ) + } +} + +pub(crate) struct RlweKskRef<'a, R> { + data: &'a [R], + decomposition_count: usize, +} +impl<'a, R: Row> RlweKskRef<'a, R> { + pub(crate) fn new(ksk: &'a [R], decomposition_count: usize) -> Self { + Self { + data: ksk, + decomposition_count, + } + } +} + +impl<'a, R: Row> RlweKsk for RlweKskRef<'a, R> { + type R = R; + + fn ksk_part_a(&self) -> &[Self::R] { + &self.data[..self.decomposition_count] + } + + fn ksk_part_b(&self) -> &[Self::R] { + &self.data[self.decomposition_count..] + } +} + +pub(crate) struct RuntimeScratchMutRef<'a, R> { + data: &'a mut [R], +} + +impl<'a, R> RuntimeScratchMutRef<'a, R> { + pub(crate) fn new(data: &'a mut [R]) -> Self { + Self { data } + } +} + +impl<'a, R: RowMut> RuntimeScratchMatrix for RuntimeScratchMutRef<'a, R> +where + R::Element: Zero + Clone, +{ + type R = R; + type Rgsw = RgswCiphertextRef<'a, R>; + + fn scratch_for_rlwe_auto_and_zero_rlwe_space( + &mut self, + decompostion_count: usize, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(decompostion_count); + let (rlwe, _) = other.split_at_mut(2); + + // zero fill rlwe + rlwe.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rlwe) + } + + fn scratch_for_rlwe_auto_trivial_case(&mut self) -> &mut Self::R { + &mut self.data[0] + } + + fn scratch_for_rgsw_x_rgsw_and_zero_rgsw0_space( + &mut self, + rgsw0_decoposer: &D, + rgsw1_decoposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( + rgsw1_decoposer.decomposition_count_a().0, + rgsw1_decoposer.decomposition_count_b().0, + )); + let (rgsw, _) = other.split_at_mut( + rgsw0_decoposer.decomposition_count_a().0 * 2 + + rgsw0_decoposer.decomposition_count_b().0 * 2, + ); + + // zero fill rgsw0 + rgsw.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rgsw) + } + + fn scratch_for_rlwe_x_rgsw_and_zero_rlwe_space( + &mut self, + decomposer: &D, + ) -> (&mut [Self::R], &mut [Self::R]) { + let (decomp_poly, other) = self.data.split_at_mut(std::cmp::max( + decomposer.decomposition_count_a().0, + decomposer.decomposition_count_b().0, + )); + + let (rlwe, _) = other.split_at_mut(2); + + // zero fill rlwe + rlwe.iter_mut() + .for_each(|r| r.as_mut().fill(R::Element::zero())); + + (decomp_poly, rlwe) + } +} + +/// Returns no. of rows in scratch space for RGSW0 x RGSW1 product +pub(crate) fn rgsw_x_rgsw_scratch_rows>( + rgsw0_decomposer_param: &D, + rgsw1_decomposer_param: &D, +) -> usize { + std::cmp::max( + rgsw1_decomposer_param.decomposition_count_a().0, + rgsw1_decomposer_param.decomposition_count_b().0, + ) + rgsw0_decomposer_param.decomposition_count_a().0 * 2 + + rgsw0_decomposer_param.decomposition_count_b().0 * 2 +} + +/// Returns no. of rows in scratch space for RLWE x RGSW product +pub(crate) fn rlwe_x_rgsw_scratch_rows>( + rgsw_decomposer_param: &D, +) -> usize { + std::cmp::max( + rgsw_decomposer_param.decomposition_count_a().0, + rgsw_decomposer_param.decomposition_count_b().0, + ) + 2 +} + +/// Returns no. of rows in scratch space for RLWE auto +pub(crate) fn rlwe_auto_scratch_rows>( + param: &D, +) -> usize { + param.decomposition_count().0 + 2 +} + +pub(crate) fn poly_fma_routine>( + write_to_row: &mut [R::Element], + matrix_a: &[R], + matrix_b: &[R], + mod_op: &ModOp, +) { + izip!(matrix_a.iter(), matrix_b.iter()).for_each(|(a, b)| { + mod_op.elwise_fma_mut(write_to_row, a.as_ref(), b.as_ref()); + }); +} + +/// Decomposes ring polynomial r(X) into d polynomials using decomposer into +/// output matrix decomp_r +/// +/// Note that decomposition of r(X) requires decomposition of each of +/// coefficients. +/// +/// - decomp_r: must have dimensions d x ring_size. i^th decomposed polynomial +/// will be stored at i^th row. +pub(crate) fn decompose_r>( + r: &[R::Element], + decomp_r: &mut [R], + decomposer: &D, +) where + R::Element: Copy, +{ + let ring_size = r.len(); + + for ri in 0..ring_size { + decomposer + .decompose_iter(&r[ri]) + .enumerate() + .for_each(|(index, el)| { + decomp_r[index].as_mut()[ri] = el; + }); + } +} + +/// Sends RLWE_{s(X)}(m(X)) -> RLWE_{s(X)}(m{X^k}) where k is some galois +/// element +/// +/// - rlwe_in: Input ciphertext RLWE_{s(X)}(m(X)). +/// - ksk: Auto key switching key with polynomials in evaluation domain +/// - auto_map_index: If automorphism sends i^th coefficient of m(X) to j^th +/// coefficient of m(X^k) then auto_map_index[i] = j +/// - auto_sign_index: With a = m(X)[i], if m(X^k)[auto_map_index[i]] = -a, then +/// auto_sign_index[i] = false, else auto_sign_index[i] = true +/// - scratch_matrix: must have dimension at-least d+2 x ring_size. `d` rows to +/// store decomposed polynomials nad 2 rows to store out RLWE temporarily. +pub(crate) fn rlwe_auto< + Rlwe: RlweCiphertext, + Ksk: RlweKsk, + Sc: RuntimeScratchMatrix, + ModOp: ArithmeticOps::Element> + + VectorOps::Element>, + NttOp: Ntt::Element>, + D: Decomposer::Element>, +>( + rlwe_in: &mut Rlwe, + ksk: &Ksk, + scratch_matrix: &mut Sc, + auto_map_index: &[usize], + auto_map_sign: &[bool], + mod_op: &ModOp, + ntt_op: &NttOp, + decomposer: &D, + is_trivial: bool, +) where + ::Element: Copy + Zero, +{ + // let ring_size = rlwe_in.dimension().1; + // assert!(rlwe_in.dimension().0 == 2); + // assert!(scratch_matrix.fits(d + 2, ring_size)); + + if !is_trivial { + let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix + .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0); + let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); + + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.part_a(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + decomposer + .decompose_iter(&el_out) + .enumerate() + .for_each(|(index, el)| { + decomp_poly_scratch[index].as_mut()[*to_index] = el; + }); + }); + + // transform decomposed a(X^k) to evaluation domain + decomp_poly_scratch.iter_mut().for_each(|r| { + ntt_op.forward(r.as_mut()); + }); + + // RLWE(m^k) = a', b'; RLWE(m) = a, b + // key switch: (a * RLWE'(s(X^k))) + // a' = decomp * RLWE'_A(s(X^k)) + poly_fma_routine( + tmp_rlwe.part_a_mut(), + decomp_poly_scratch, + ksk.ksk_part_a(), + mod_op, + ); + + // b' += decomp * RLWE'_B(s(X^k)) + poly_fma_routine( + tmp_rlwe.part_b_mut(), + decomp_poly_scratch, + ksk.ksk_part_b(), + mod_op, + ); + + // transform RLWE(m^k) to coefficient domain + ntt_op.backward(tmp_rlwe.part_a_mut()); + ntt_op.backward(tmp_rlwe.part_b_mut()); + + // send b(X) -> b(X^k) and then b'(X) += b(X^k) + izip!( + rlwe_in.part_b(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let row = tmp_rlwe.part_b_mut(); + if !*sign { + row[*to_index] = mod_op.sub(&row[*to_index], el_in); + } else { + row[*to_index] = mod_op.add(&row[*to_index], el_in); + } + }); + + // copy over A; Leave B for later + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe.part_a()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe.part_b()); + } else { + // RLWE is trivial, a(X) is 0. + // send b(X) -> b(X^k) + let tmp_row = scratch_matrix.scratch_for_rlwe_auto_trivial_case(); + izip!( + rlwe_in.part_b(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + tmp_row.as_mut()[*to_index] = mod_op.neg(el_in); + } else { + tmp_row.as_mut()[*to_index] = *el_in; + } + }); + rlwe_in.part_b_mut().copy_from_slice(tmp_row.as_ref()); + } +} + +/// Sends RLWE_{s(X)}(m(X)) -> RLWE_{s(X)}(m{X^k}) where k is some galois +/// element +/// +/// This is same as `galois_auto` with the difference that alongside `ksk` with +/// key switching polynomials in evaluation domain, shoup representation, +/// `ksk_shoup`, of the polynomials in evaluation domain is also supplied. +pub(crate) fn rlwe_auto_shoup< + Rlwe: RlweCiphertext, + Ksk: RlweKsk, + Sc: RuntimeScratchMatrix, + ModOp: ArithmeticOps::Element> + // + VectorOps + + ShoupMatrixFMA, + NttOp: Ntt::Element>, + D: Decomposer::Element>, +>( + rlwe_in: &mut Rlwe, + ksk: &Ksk, + ksk_shoup: &Ksk, + scratch_matrix: &mut Sc, + auto_map_index: &[usize], + auto_map_sign: &[bool], + mod_op: &ModOp, + ntt_op: &NttOp, + decomposer: &D, + is_trivial: bool, +) where + ::Element: Copy + Zero, +{ + // let d = decomposer.decomposition_count(); + // let ring_size = rlwe_in.dimension().1; + // assert!(rlwe_in.dimension().0 == 2); + // assert!(scratch_matrix.fits(d + 2, ring_size)); + + if !is_trivial { + let (decomp_poly_scratch, tmp_rlwe) = scratch_matrix + .scratch_for_rlwe_auto_and_zero_rlwe_space(decomposer.decomposition_count().0); + let mut tmp_rlwe = RlweCiphertextMutRef::new(tmp_rlwe); + + // send a(X) -> a(X^k) and decompose a(X^k) + izip!( + rlwe_in.part_a(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; + + decomposer + .decompose_iter(&el_out) + .enumerate() + .for_each(|(index, el)| { + decomp_poly_scratch[index].as_mut()[*to_index] = el; + }); + }); + + // transform decomposed a(X^k) to evaluation domain + decomp_poly_scratch.iter_mut().for_each(|r| { + ntt_op.forward_lazy(r.as_mut()); + }); + + // RLWE(m^k) = a', b'; RLWE(m) = a, b + // key switch: (a * RLWE'(s(X^k))) + // a' = decomp * RLWE'_A(s(X^k)) + mod_op.shoup_matrix_fma( + tmp_rlwe.part_a_mut(), + ksk.ksk_part_a(), + ksk_shoup.ksk_part_a(), + decomp_poly_scratch, + ); + + // b'= decomp * RLWE'_B(s(X^k)) + mod_op.shoup_matrix_fma( + tmp_rlwe.part_b_mut(), + ksk.ksk_part_b(), + ksk_shoup.ksk_part_b(), + decomp_poly_scratch, + ); + + // transform RLWE(m^k) to coefficient domain + ntt_op.backward(tmp_rlwe.part_a_mut()); + ntt_op.backward(tmp_rlwe.part_b_mut()); + + // send b(X) -> b(X^k) and then b'(X) += b(X^k) + let row = tmp_rlwe.part_b_mut(); + izip!( + rlwe_in.part_b(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + row[*to_index] = mod_op.sub(&row[*to_index], el_in); + } else { + row[*to_index] = mod_op.add(&row[*to_index], el_in); + } + }); + + // copy over A, B + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe.part_a()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe.part_b()); + } else { + // RLWE is trivial, a(X) is 0. + // send b(X) -> b(X^k) + let row = scratch_matrix.scratch_for_rlwe_auto_trivial_case(); + izip!( + rlwe_in.part_b(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(el_in, to_index, sign)| { + if !*sign { + row.as_mut()[*to_index] = mod_op.neg(el_in); + } else { + row.as_mut()[*to_index] = *el_in; + } + }); + rlwe_in.part_b_mut().copy_from_slice(row.as_ref()); + } +} + +/// Inplace mutates RLWE(m0) to equal RLWE(m0m1) = RLWE(m0) x RGSW(m1). +/// +/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain +/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain +/// - scratch_matrix: with dimension (max(d_a, d_b) + 2) x ring_size columns. +/// It's used to store decomposed polynomials and out RLWE temporarily +pub(crate) fn rlwe_by_rgsw< + Rlwe: RlweCiphertext, + Rgsw: RgswCiphertext, + Sc: RuntimeScratchMatrix, + D: RlweDecomposer::Element>, + ModOp: VectorOps::Element>, + NttOp: Ntt::Element>, +>( + rlwe_in: &mut Rlwe, + rgsw_in: &Rgsw, + scratch_matrix: &mut Sc, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, + is_trivial: bool, +) where + ::Element: Copy + Zero, +{ + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer.decomposition_count_a().0; + let d_b = decomposer.decomposition_count_b().0; + + let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = + rgsw_in.split(); + + let (decomposed_poly_scratch, tmp_rlwe) = + scratch_matrix.scratch_for_rlwe_x_rgsw_and_zero_rlwe_space(decomposer); + + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out + if !is_trivial { + // a_in = 0 when RLWE_in is trivial RLWE ciphertext + // decomp + let mut decomposed_polys_of_rlwea = &mut decomposed_poly_scratch[..d_a]; + decompose_r( + rlwe_in.part_a(), + &mut decomposed_polys_of_rlwea, + decomposer_a, + ); + + decomposed_polys_of_rlwea + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(-sm) + poly_fma_routine( + tmp_rlwe[0].as_mut(), + &decomposed_polys_of_rlwea, + rlwe_dash_nsm_parta, + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(-sm) + poly_fma_routine( + tmp_rlwe[1].as_mut(), + &decomposed_polys_of_rlwea, + &rlwe_dash_nsm_partb, + mod_op, + ); + } + + { + // decomp + let mut decomposed_polys_of_rlweb = &mut decomposed_poly_scratch[..d_b]; + decompose_r( + rlwe_in.part_b(), + &mut decomposed_polys_of_rlweb, + decomposer_b, + ); + + decomposed_polys_of_rlweb + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(m) + poly_fma_routine( + tmp_rlwe[0].as_mut(), + &decomposed_polys_of_rlweb, + &rlwe_dash_m_parta, + mod_op, + ); + // b_out += decomp \cdot RLWE_B'(m) + poly_fma_routine( + tmp_rlwe[1].as_mut(), + &decomposed_polys_of_rlweb, + &rlwe_dash_m_partb, + mod_op, + ); + } + + // transform rlwe_out to coefficient domain + tmp_rlwe + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe[0].as_mut()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe[1].as_mut()); +} + +/// Inplace mutates RLWE(m0) to equal RLWE(m0m1) = RLWE(m0) x RGSW(m1). +/// +/// Same as `rlwe_by_rgsw` with the difference that alongside `rgsw_in` with +/// polynomials in evaluation domain, shoup representation of polynomials in +/// evaluation domain, `rgsw_in_shoup`, is also supplied. +pub(crate) fn rlwe_by_rgsw_shoup< + Rlwe: RlweCiphertext, + Rgsw: RgswCiphertext, + Sc: RuntimeScratchMatrix, + D: RlweDecomposer::Element>, + ModOp: ShoupMatrixFMA, + NttOp: Ntt::Element>, +>( + rlwe_in: &mut Rlwe, + rgsw_in: &Rgsw, + rgsw_in_shoup: &Rgsw, + scratch_matrix: &mut Sc, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, + is_trivial: bool, +) where + ::Element: Copy + Zero, +{ + let decomposer_a = decomposer.a(); + let decomposer_b = decomposer.b(); + let d_a = decomposer.decomposition_count_a().0; + let d_b = decomposer.decomposition_count_b().0; + + let ((rlwe_dash_nsm_parta, rlwe_dash_nsm_partb), (rlwe_dash_m_parta, rlwe_dash_m_partb)) = + rgsw_in.split(); + + let ( + (rlwe_dash_nsm_parta_shoup, rlwe_dash_nsm_partb_shoup), + (rlwe_dash_m_parta_shoup, rlwe_dash_m_partb_shoup), + ) = rgsw_in_shoup.split(); + + let (decomposed_poly_scratch, tmp_rlwe) = + scratch_matrix.scratch_for_rlwe_x_rgsw_and_zero_rlwe_space(decomposer); + + // RLWE_in = a_in, b_in; RLWE_out = a_out, b_out + if !is_trivial { + // a_in = 0 when RLWE_in is trivial RLWE ciphertext + // decomp + let mut decomposed_polys_of_rlwea = &mut decomposed_poly_scratch[..d_a]; + decompose_r( + rlwe_in.part_a(), + &mut decomposed_polys_of_rlwea, + decomposer_a, + ); + decomposed_polys_of_rlwea + .iter_mut() + .for_each(|r| ntt_op.forward_lazy(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(-sm) + mod_op.shoup_matrix_fma( + tmp_rlwe[0].as_mut(), + &rlwe_dash_nsm_parta, + &rlwe_dash_nsm_parta_shoup, + &decomposed_polys_of_rlwea, + ); + + // b_out += decomp \cdot RLWE_B'(-sm) + mod_op.shoup_matrix_fma( + tmp_rlwe[1].as_mut(), + &rlwe_dash_nsm_partb, + &rlwe_dash_nsm_partb_shoup, + &decomposed_polys_of_rlwea, + ); + } + { + // decomp + let mut decomposed_polys_of_rlweb = &mut decomposed_poly_scratch[..d_b]; + decompose_r( + rlwe_in.part_b(), + &mut decomposed_polys_of_rlweb, + decomposer_b, + ); + decomposed_polys_of_rlweb + .iter_mut() + .for_each(|r| ntt_op.forward_lazy(r.as_mut())); + + // a_out += decomp \cdot RLWE_A'(m) + mod_op.shoup_matrix_fma( + tmp_rlwe[0].as_mut(), + &rlwe_dash_m_parta, + &rlwe_dash_m_parta_shoup, + &decomposed_polys_of_rlweb, + ); + + // b_out += decomp \cdot RLWE_B'(m) + mod_op.shoup_matrix_fma( + tmp_rlwe[1].as_mut(), + &rlwe_dash_m_partb, + &rlwe_dash_m_partb_shoup, + &decomposed_polys_of_rlweb, + ); + } + + // transform rlwe_out to coefficient domain + tmp_rlwe + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + rlwe_in.part_a_mut().copy_from_slice(tmp_rlwe[0].as_mut()); + rlwe_in.part_b_mut().copy_from_slice(tmp_rlwe[1].as_mut()); +} + +/// Inplace mutates RGSW(m0) to equal RGSW(m0m1) = RGSW(m0)xRGSW(m1) +/// +/// RGSW x RGSW product requires multiple RLWE x RGSW products. For example, +/// Define +/// +/// RGSW(m0) = [RLWE(-sm), RLWE(\beta -sm), ..., RLWE(\beta^{d-1} -sm) +/// RLWE(m), RLWE(\beta m), ..., RLWE(\beta^{d-1} m)] +/// And RGSW(m1) +/// +/// Then RGSW(m0) x RGSW(m1) equals: +/// RGSW(m0m1) = [ +/// rlwe_x_rgsw(RLWE(-sm), RGSW(m1)), +/// ..., +/// rlwe_x_rgsw(RLWE(\beta^{d-1} -sm), RGSW(m1)), +/// rlwe_x_rgsw(RLWE(m), RGSW(m1)), +/// ..., +/// rlwe_x_rgsw(RLWE(\beta^{d-1} m), RGSW(m1)), +/// ] +/// +/// Since noise growth in RLWE x RGSW depends on noise in RGSW ciphertext, it is +/// clear to observe from above that noise in resulting RGSW(m0m1) equals noise +/// accumulated in a single RLWE x RGSW and depends on noise in RGSW(m1) (i.e. +/// rgsw_1_eval) +/// +/// - rgsw_0: RGSW(m0) in coefficient domain +/// - rgsw_1_eval: RGSW(m1) in evaluation domain +pub(crate) fn rgsw_by_rgsw_inplace< + Rgsw: RgswCiphertext, + RgswMut: RgswCiphertextMut, + Sc: RuntimeScratchMatrix, + D: RlweDecomposer::Element>, + ModOp: VectorOps::Element>, + NttOp: Ntt::Element>, +>( + rgsw0: &mut RgswMut, + rgsw1_eval: &Rgsw, + rgsw0_decomposer: &D, + rgsw1_decomposer: &D, + scratch_matrix: &mut Sc, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + ::Element: Copy + Zero, + RgswMut: AsMut<[Rgsw::R]>, + RgswMut::R: RowMut, + // Rgsw: AsRef<[Rgsw::R]>, +{ + let (decomp_r_space, rgsw_space) = scratch_matrix + .scratch_for_rgsw_x_rgsw_and_zero_rgsw0_space(rgsw0_decomposer, rgsw1_decomposer); + + let mut rgsw_space = RgswCiphertextMutRef::new( + rgsw_space, + rgsw0_decomposer.decomposition_count_a().0, + rgsw0_decomposer.decomposition_count_b().0, + ); + let ( + (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb), + (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb), + ) = rgsw_space.split_mut(); + + let ((rgsw0_nsm_parta, rgsw0_nsm_partb), (rgsw0_m_parta, rgsw0_m_partb)) = rgsw0.split(); + let ((rgsw1_nsm_parta, rgsw1_nsm_partb), (rgsw1_m_parta, rgsw1_m_partb)) = rgsw1_eval.split(); + + // RGSW x RGSW + izip!( + rgsw0_nsm_parta.iter().chain(rgsw0_m_parta), + rgsw0_nsm_partb.iter().chain(rgsw0_m_partb), + rlwe_dash_space_nsm_parta + .iter_mut() + .chain(rlwe_dash_space_m_parta.iter_mut()), + rlwe_dash_space_nsm_partb + .iter_mut() + .chain(rlwe_dash_space_m_partb.iter_mut()), + ) + .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { + // RLWE(m0) x RGSW(m1) + + // Part A: Decomp \cdot RLWE'(-sm1) + { + let decomp_r_parta = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_a().0]; + decompose_r( + rlwe_a.as_ref(), + decomp_r_parta.as_mut(), + rgsw1_decomposer.a(), + ); + decomp_r_parta + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + poly_fma_routine( + rlwe_out_a.as_mut(), + &decomp_r_parta, + &rgsw1_nsm_parta, + mod_op, + ); + poly_fma_routine( + rlwe_out_b.as_mut(), + &decomp_r_parta, + &rgsw1_nsm_partb, + mod_op, + ); + } + + // Part B: Decompose \cdot RLWE'(m1) + { + let decomp_r_partb = &mut decomp_r_space[..rgsw1_decomposer.decomposition_count_b().0]; + decompose_r( + rlwe_b.as_ref(), + decomp_r_partb.as_mut(), + rgsw1_decomposer.b(), + ); + decomp_r_partb + .iter_mut() + .for_each(|ri| ntt_op.forward(ri.as_mut())); + poly_fma_routine(rlwe_out_a.as_mut(), &decomp_r_partb, &rgsw1_m_parta, mod_op); + poly_fma_routine(rlwe_out_b.as_mut(), &decomp_r_partb, &rgsw1_m_partb, mod_op); + } + }); + + // copy over RGSW(m0m1) to RGSW(m0) + // let d = rgsw0.as_mut(); + izip!(rgsw0.as_mut().iter_mut(), rgsw_space.data.iter()) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send back to coefficient domain + rgsw0 + .as_mut() + .iter_mut() + .for_each(|ri| ntt_op.backward(ri.as_mut())); +} + +/// Key switches input RLWE_{s'}(m) -> RLWE_{s}(m) +/// +/// Let RLWE_{s'}(m) = [a, b] s.t. m+e = b - as' +/// +/// Given key switchin key Ksk(s' -> s) = RLWE'_{s}(s') = [RLWE_{s}(beta^i s')] +/// = [a, a*s + e + beta^i s'] for i \in [0,d), key switching computes: +/// 1. RLWE_{s}(-s'a) = \sum signed_decompose(-a)[i] RLWE_{s}(beta^i s') +/// 2. RLWE_{s}(m) = (b, 0) + RLWE_{s}(-s'a) +/// +/// - rlwe_in: Input rlwe ciphertext +/// - ksk: Key switching key Ksk(s' -> s) with polynomials in evaluation domain +/// - ksk_shoup: Key switching key Ksk(s' -> s) with polynomials in evaluation +/// domain in shoup representation +/// - decomposer: Decomposer used for key switching +pub(crate) fn rlwe_key_switch< + M: MatrixMut + MatrixEntity, + ModOp: GetModulus + ShoupMatrixFMA + VectorOps, + NttOp: Ntt, + D: Decomposer, +>( + rlwe_in: &M, + ksk: &M, + ksk_shoup: &M, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, +) -> M +where + ::R: RowMut + RowEntity, + M::MatElement: Copy, +{ + let ring_size = rlwe_in.dimension().1; + assert!(rlwe_in.dimension().0 == 2); + assert!(ksk.dimension() == (decomposer.decomposition_count().0 * 2, ring_size)); + + let mut rlwe_out = M::zeros(2, ring_size); + + let mut tmp = M::zeros(decomposer.decomposition_count().0, ring_size); + let mut tmp_row = M::R::zeros(ring_size); + + // key switch RLWE part -A + // negative A + tmp_row.as_mut().copy_from_slice(rlwe_in.get_row_slice(0)); + mod_op.elwise_neg_mut(tmp_row.as_mut()); + // decompose -A and send to evaluation domain + decompose_r(tmp_row.as_ref(), tmp.as_mut(), decomposer); + tmp.iter_rows_mut() + .for_each(|r| ntt_op.forward_lazy(r.as_mut())); + + // RLWE_s(-A u) = B' + B, A' = (decomp(-A) * Ksk(u -> s)) + (B, 0) + let (ksk_part_a, ksk_part_b) = ksk.split_at_row(decomposer.decomposition_count().0); + let (ksk_part_a_shoup, ksk_part_b_shoup) = + ksk_shoup.split_at_row(decomposer.decomposition_count().0); + // Part A' + mod_op.shoup_matrix_fma( + rlwe_out.get_row_mut(0), + &ksk_part_a, + &ksk_part_a_shoup, + tmp.as_ref(), + ); + // Part B' + mod_op.shoup_matrix_fma( + rlwe_out.get_row_mut(1), + &ksk_part_b, + &ksk_part_b_shoup, + tmp.as_ref(), + ); + // back to coefficient domain + rlwe_out + .iter_rows_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); + + // B' + B + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), rlwe_in.get_row_slice(1)); + + rlwe_out +} diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs new file mode 100644 index 0000000..e8f8e77 --- /dev/null +++ b/src/shortint/enc_dec.rs @@ -0,0 +1,370 @@ +use itertools::Itertools; + +use crate::{ + bool::BoolEvaluator, + random::{DefaultSecureRng, RandomFillUniformInModulus}, + utils::WithLocal, + Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, + RowMut, SampleExtractor, +}; + +/// Fhe UInt8 +/// +/// Note that `Self.data` stores encryptions of bits in little endian (i.e least +/// signficant bit stored at 0th index and most signficant bit stores at 7th +/// index) +#[derive(Clone)] +pub struct FheUint8 { + pub(super) data: Vec, +} + +impl FheUint8 { + pub(super) fn data(&self) -> &[C] { + &self.data + } + + pub(super) fn data_mut(&mut self) -> &mut [C] { + &mut self.data + } +} + +/// Stores a batch of Fhe Uint8 ciphertext as collection of unseeded RLWE +/// ciphertexts always encrypted under the ideal RLWE secret `s` of the MPC +/// protocol +/// +/// To extract Fhe Uint8 ciphertext at `index` call `self.extract(index)` +pub struct BatchedFheUint8 { + /// Vector of RLWE ciphertexts `C` + data: Vec, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, +} + +impl Encryptor<[u8], BatchedFheUint8> for K +where + K: Encryptor<[bool], Vec>, +{ + /// Encrypt a batch of uint8s packed in vector of RLWE ciphertexts + /// + /// Uint8s can be extracted from `BatchedFheUint8` with `SampleExtractor` + fn encrypt(&self, m: &[u8]) -> BatchedFheUint8 { + let bool_m = m + .iter() + .flat_map(|v| { + (0..8) + .into_iter() + .map(|i| ((*v >> i) & 1) == 1) + .collect_vec() + }) + .collect_vec(); + let cts = K::encrypt(&self, &bool_m); + BatchedFheUint8 { + data: cts, + count: m.len(), + } + } +} + +impl> From<&SeededBatchedFheUint8> + for BatchedFheUint8 +where + ::R: RowMut, +{ + /// Unseeds collection of seeded RLWE ciphertext in SeededBatchedFheUint8 + /// and returns as `Self` + fn from(value: &SeededBatchedFheUint8) -> Self { + BoolEvaluator::with_local(|e| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + let rlwe_q = parameters.rlwe_q(); + + let mut prng = DefaultSecureRng::new_seeded(value.seed); + let rlwes = value + .data + .iter() + .map(|partb| { + let mut rlwe = M::zeros(2, ring_size); + + // sample A + RandomFillUniformInModulus::random_fill(&mut prng, rlwe_q, rlwe.get_row_mut(0)); + + // Copy over B + rlwe.get_row_mut(1).copy_from_slice(partb.as_ref()); + + rlwe + }) + .collect_vec(); + Self { + data: rlwes, + count: value.count, + } + }) + } +} + +impl SampleExtractor> for BatchedFheUint8 +where + C: SampleExtractor, +{ + /// Extract Fhe Uint8 ciphertext at `index` + /// + /// `Self` stores batch of Fhe uint8 ciphertext as vector of RLWE + /// ciphertexts. Since Fhe uint8 ciphertext is collection of 8 bool + /// ciphertexts, Fhe uint8 ciphertext at index `i` is stored in coefficients + /// `i*8...(i+1)*8`. To extract Fhe uint8 at index `i`, sample extract bool + /// ciphertext at indices `[i*8, ..., (i+1)*8)` + fn extract_at(&self, index: usize) -> FheUint8 { + assert!(index < self.count); + BoolEvaluator::with_local(|e| { + let ring_size = e.parameters().rlwe_n().0; + + let start_index = index * 8; + let end_index = (index + 1) * 8; + let data = (start_index..end_index) + .map(|i| { + let rlwe_index = i / ring_size; + let coeff_index = i % ring_size; + self.data[rlwe_index].extract_at(coeff_index) + }) + .collect_vec(); + FheUint8 { data } + }) + } + + /// Extracts all FheUint8s packed in vector of RLWE ciphertexts of `Self` + fn extract_all(&self) -> Vec> { + (0..self.count) + .map(|index| self.extract_at(index)) + .collect_vec() + } + + /// Extracts first `how_many` FheUint8s packed in vector of RLWE + /// ciphertexts of `Self` + fn extract_many(&self, how_many: usize) -> Vec> { + (0..how_many) + .map(|index| self.extract_at(index)) + .collect_vec() + } +} + +/// Stores a batch of FheUint8s packed in a collection unseeded RLWE ciphertexts +/// +/// `Self` stores unseeded RLWE ciphertexts encrypted under user's RLWE secret +/// `u_j` and is different from `BatchFheUint8` which stores collection of RLWE +/// ciphertexts under ideal RLWE secret `s` of the (non-interactive/interactive) +/// MPC protocol. +/// +/// To extract FheUint8s from `Self`'s collection of RLWE ciphertexts, first +/// switch `Self` to `BatchFheUint8` with `key_switch(user_id)` where `user_id` +/// is user's id. This key switches collection of RLWE ciphertexts from +/// user's RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol. Then +/// proceed to use `SampleExtract` on `BatchFheUint8` (for ex, call +/// `extract_at(0)` to extract FheUint8 stored at index 0) +pub struct NonInteractiveBatchedFheUint8 { + /// Vector of RLWE ciphertexts `C` + data: Vec, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, +} + +impl> From<&SeededBatchedFheUint8> + for NonInteractiveBatchedFheUint8 +where + ::R: RowMut, +{ + /// Unseeds collection of seeded RLWE ciphertext in SeededBatchedFheUint8 + /// and returns as `Self` + fn from(value: &SeededBatchedFheUint8) -> Self { + BoolEvaluator::with_local(|e| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + let rlwe_q = parameters.rlwe_q(); + + let mut prng = DefaultSecureRng::new_seeded(value.seed); + let rlwes = value + .data + .iter() + .map(|partb| { + let mut rlwe = M::zeros(2, ring_size); + + // sample A + RandomFillUniformInModulus::random_fill(&mut prng, rlwe_q, rlwe.get_row_mut(0)); + + // Copy over B + rlwe.get_row_mut(1).copy_from_slice(partb.as_ref()); + + rlwe + }) + .collect_vec(); + Self { + data: rlwes, + count: value.count, + } + }) + } +} + +impl KeySwitchWithId> for NonInteractiveBatchedFheUint8 +where + C: KeySwitchWithId, +{ + /// Key switch `Self`'s collection of RLWE cihertexts encrypted under user's + /// RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol. + /// + /// - user_id: user id of user `j` + fn key_switch(&self, user_id: usize) -> BatchedFheUint8 { + let data = self + .data + .iter() + .map(|c| c.key_switch(user_id)) + .collect_vec(); + BatchedFheUint8 { + data, + count: self.count, + } + } +} + +pub struct SeededBatchedFheUint8 { + /// Vector of Seeded RLWE ciphertexts `C`. + /// + /// If RLWE(m) = [a, b] s.t. m + e = b - as, `a` can be seeded and seeded + /// RLWE ciphertext only contains `b` polynomial + data: Vec, + /// Seed for the ciphertexts + seed: S, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, +} + +impl Encryptor<[u8], SeededBatchedFheUint8> for K +where + K: Encryptor<[bool], (Vec, S)>, +{ + /// Encrypt a slice of u8s of arbitray length packed into collection of + /// seeded RLWE ciphertexts and return `SeededBatchedFheUint8` + fn encrypt(&self, m: &[u8]) -> SeededBatchedFheUint8 { + // convert vector of u8s to vector bools + let bool_m = m + .iter() + .flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1)) + .collect_vec(); + let (cts, seed) = K::encrypt(&self, &bool_m); + SeededBatchedFheUint8 { + data: cts, + seed, + count: m.len(), + } + } +} + +impl SeededBatchedFheUint8 { + /// Unseed collection of seeded RLWE ciphertexts of `Self` and returns + /// `NonInteractiveBatchedFheUint8` with collection of unseeded RLWE + /// ciphertexts. + /// + /// In non-interactive MPC setting, RLWE ciphertexts are encrypted under + /// user's RLWE secret `u_j`. The RLWE ciphertexts must be key switched to + /// ideal RLWE secret `s` of the MPC protocol before use. + /// + /// Note that we don't provide `unseed` API from `Self` to + /// `BatchedFheUint8`. This is because: + /// + /// - In non-interactive setting (1) client encrypts private inputs using + /// their secret `u_j` as `SeededBatchedFheUint8` and sends it to the + /// server. (2) Server unseeds `SeededBatchedFheUint8` into + /// `NonInteractiveBatchedFheUint8` indicating that private inputs are + /// still encrypted under user's RLWE secret `u_j`. (3) Server key + /// switches `NonInteractiveBatchedFheUint8` from user's RLWE secret `u_j` + /// to ideal RLWE secret `s` and outputs `BatchedFheUint8`. (4) + /// `BatchedFheUint8` always stores RLWE secret under ideal RLWE secret of + /// the protocol. Hence, it is safe to extract FheUint8s. Server proceeds + /// to extract necessary FheUint8s. + /// + /// - In interactive setting (1) client always encrypts private inputs using + /// public key corresponding to ideal RLWE secret `s` of the protocol and + /// produces `BatchedFheUint8`. (2) Given `BatchedFheUint8` stores + /// collection of RLWE ciphertext under ideal RLWE secret `s`, server can + /// directly extract necessary FheUint8s to use. + /// + /// Thus, there's no need to go directly from `Self` to `BatchedFheUint8`. + pub fn unseed(&self) -> NonInteractiveBatchedFheUint8 + where + NonInteractiveBatchedFheUint8: for<'a> From<&'a SeededBatchedFheUint8>, + M: Matrix, + { + NonInteractiveBatchedFheUint8::from(self) + } +} + +impl MultiPartyDecryptor> for K +where + K: MultiPartyDecryptor, + >::DecryptionShare: Clone, +{ + type DecryptionShare = Vec<>::DecryptionShare>; + fn gen_decryption_share(&self, c: &FheUint8) -> Self::DecryptionShare { + assert!(c.data().len() == 8); + c.data() + .iter() + .map(|bit_c| { + let decryption_share = + MultiPartyDecryptor::::gen_decryption_share(self, bit_c); + decryption_share + }) + .collect_vec() + } + + fn aggregate_decryption_shares(&self, c: &FheUint8, shares: &[Self::DecryptionShare]) -> u8 { + let mut out = 0u8; + + (0..8).into_iter().for_each(|i| { + // Collect bit i^th decryption share of each party + let bit_i_decryption_shares = shares.iter().map(|s| s[i].clone()).collect_vec(); + let bit_i = MultiPartyDecryptor::::aggregate_decryption_shares( + self, + &c.data()[i], + &bit_i_decryption_shares, + ); + + if bit_i { + out += 1 << i; + } + }); + + out + } +} + +impl Encryptor> for K +where + K: Encryptor, +{ + fn encrypt(&self, m: &u8) -> FheUint8 { + let cts = (0..8) + .into_iter() + .map(|i| { + let bit = ((m >> i) & 1) == 1; + K::encrypt(self, &bit) + }) + .collect_vec(); + FheUint8 { data: cts } + } +} + +impl Decryptor> for K +where + K: Decryptor, +{ + fn decrypt(&self, c: &FheUint8) -> u8 { + assert!(c.data.len() == 8); + let mut out = 0u8; + c.data().iter().enumerate().for_each(|(index, bit_c)| { + let bool = K::decrypt(self, bit_c); + if bool { + out += 1 << index; + } + }); + out + } +} diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs new file mode 100644 index 0000000..a70e0cd --- /dev/null +++ b/src/shortint/mod.rs @@ -0,0 +1,294 @@ +mod enc_dec; +mod ops; + +pub type FheUint8 = enc_dec::FheUint8>; + +use std::cell::RefCell; + +use crate::bool::{BoolEvaluator, BooleanGates, FheBool, RuntimeServerKey}; + +thread_local! { + static DIV_ZERO_ERROR: RefCell> = RefCell::new(None); +} + +/// Returns Boolean ciphertext indicating whether last division was attempeted +/// with decnomiantor set to 0. +pub fn div_zero_error_flag() -> Option { + DIV_ZERO_ERROR.with_borrow(|c| c.clone()) +} + +/// Reset all error flags +/// +/// Error flags are thread local. When running multiple circuits in sequence +/// within a single program you must prevent error flags set during the +/// execution of previous circuit to affect error flags set during execution of +/// the next circuit. To do so call `reset_error_flags()`. +pub fn reset_error_flags() { + DIV_ZERO_ERROR.with_borrow_mut(|c| *c = None); +} + +mod frontend { + use super::ops::{ + arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, + eight_bit_mul, is_zero, + }; + use crate::utils::{Global, WithLocal}; + + use super::*; + + /// Set Div by Zero flag after each divison. Div by zero flag is set to true + /// if either 1 of the division executed in circuit evaluation has + /// denominator set to 0. + fn set_div_by_zero_flag(denominator: &FheUint8) { + { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let is_zero = is_zero(e, denominator.data(), key); + DIV_ZERO_ERROR.with_borrow_mut(|before_is_zero| { + if before_is_zero.is_none() { + *before_is_zero = Some(FheBool { data: is_zero }); + } else { + e.or_inplace(before_is_zero.as_mut().unwrap().data_mut(), &is_zero, key); + } + }); + }) + } + } + + mod arithetic { + + use super::*; + use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; + + impl AddAssign<&FheUint8> for FheUint8 { + fn add_assign(&mut self, rhs: &FheUint8) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); + }); + } + } + + impl Add<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn add(self, rhs: &FheUint8) -> Self::Output { + let mut a = self.clone(); + a += rhs; + a + } + } + + impl Sub<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn sub(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); + FheUint8 { data: out } + }) + } + } + + impl Mul<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn mul(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = eight_bit_mul(e, self.data(), rhs.data(), key); + FheUint8 { data: out } + }) + } + } + + impl Div<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn div(self, rhs: &FheUint8) -> Self::Output { + // set div by 0 error flag + set_div_by_zero_flag(rhs); + + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + + let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + FheUint8 { data: quotient } + }) + } + } + + impl Rem<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn rem(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + FheUint8 { data: remainder } + }) + } + } + + impl FheUint8 { + /// Calculates `Self += rhs` and returns `overflow` + /// + /// `overflow` is set to `True` if `Self += rhs` overflowed, + /// otherwise it is set to `False` + pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + let (overflow, _) = + arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); + FheBool { data: overflow } + }) + } + + /// Returns (Self + rhs, overflow). + /// + /// `overflow` is set to `True` if `Self + rhs` overflowed, + /// otherwise it is set to `False` + pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) { + BoolEvaluator::with_local_mut(|e| { + let mut lhs = self.clone(); + let key = RuntimeServerKey::global(); + let (overflow, _) = + arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key); + (lhs, FheBool { data: overflow }) + }) + } + + /// Returns (Self - rhs, overflow). + /// + /// `overflow` is set to `True` if `Self - rhs` overflowed, + /// otherwise it is set to `False` + pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let (out, mut overflow, _) = + arbitrary_bit_subtractor(e, self.data(), rhs.data(), key); + e.not_inplace(&mut overflow); + (FheUint8 { data: out }, FheBool { data: overflow }) + }) + } + + /// Returns (quotient, remainder) s.t. self = rhs x quotient + + /// remainder. + /// + /// If rhs is 0, then quotient = 255, remainder = self, and Div by + /// Zero error flag (accessible via `div_zero_error_flag`) is set to + /// `True` + pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) { + // set div by 0 error flag + set_div_by_zero_flag(rhs); + + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + + let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + (FheUint8 { data: quotient }, FheUint8 { data: remainder }) + }) + } + } + } + + mod booleans { + use crate::shortint::ops::{ + arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_bit_mux, + }; + + use super::*; + + impl FheUint8 { + /// Returns `FheBool` indicating `Self == other` + pub fn eq(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = arbitrary_bit_equality(e, self.data(), other.data(), key); + FheBool { data: out } + }) + } + + /// Returns `FheBool` indicating `Self != other` + pub fn neq(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key); + e.not_inplace(&mut is_equal); + FheBool { data: is_equal } + }) + } + + /// Returns `FheBool` indicating `Self < other` + pub fn lt(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = arbitrary_bit_comparator(e, other.data(), self.data(), key); + FheBool { data: out } + }) + } + + /// Returns `FheBool` indicating `Self > other` + pub fn gt(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = arbitrary_bit_comparator(e, self.data(), other.data(), key); + FheBool { data: out } + }) + } + + /// Returns `FheBool` indicating `Self <= other` + pub fn le(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let mut a_greater_b = + arbitrary_bit_comparator(e, self.data(), other.data(), key); + e.not_inplace(&mut a_greater_b); + FheBool { data: a_greater_b } + }) + } + + /// Returns `FheBool` indicating `Self >= other` + pub fn ge(&self, other: &FheUint8) -> FheBool { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key); + e.not_inplace(&mut a_less_b); + FheBool { data: a_less_b } + }) + } + + /// Returns `Self` if `selector = True` else returns `other` + pub fn mux(&self, other: &FheUint8, selector: &FheBool) -> FheUint8 { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + let out = arbitrary_bit_mux(e, selector.data(), self.data(), other.data(), key); + FheUint8 { data: out } + }) + } + + /// Returns max(`Self`, `other`) + pub fn max(&self, other: &FheUint8) -> FheUint8 { + let self_gt = self.gt(other); + self.mux(other, &self_gt) + } + + /// Returns min(`Self`, `other`) + pub fn min(&self, other: &FheUint8) -> FheUint8 { + let self_lt = self.lt(other); + self.mux(other, &self_lt) + } + } + } +} diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs new file mode 100644 index 0000000..1beae6a --- /dev/null +++ b/src/shortint/ops.rs @@ -0,0 +1,356 @@ +use itertools::{izip, Itertools}; + +use crate::bool::BooleanGates; + +pub(super) fn half_adder( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let carry = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); + carry +} + +pub(super) fn full_adder_plain_carry_in( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + carry_in: bool, + key: &E::Key, +) -> E::Ciphertext { + let mut a_and_b = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); //a = a ^ b + if carry_in { + // a_and_b = A & B | ((A^B) & C_in={True}) + evaluator.or_inplace(&mut a_and_b, &a, key); + } else { + // a_and_b = A & B | ((A^B) & C_in={False}) + // a_and_b = A & B + // noop + } + + // In xor if a input is 0, output equals the firt variable. If input is 1 then + // output equals !(first variable) + if carry_in { + // (A^B)^1 = !(A^B) + evaluator.not_inplace(a); + } else { + // (A^B)^0 + // no-op + } + a_and_b +} + +pub(super) fn full_adder( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + carry_in: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let mut a_and_b = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); //a = a ^ b + let a_xor_b_and_c = evaluator.and(&a, carry_in, key); + evaluator.or_inplace(&mut a_and_b, &a_xor_b_and_c, key); // a_and_b = A & B | ((A^B) & C_in) + evaluator.xor_inplace(a, &carry_in, key); + a_and_b +} + +pub(super) fn arbitrary_bit_adder( + evaluator: &mut E, + a: &mut [E::Ciphertext], + b: &[E::Ciphertext], + carry_in: bool, + key: &E::Key, +) -> (E::Ciphertext, E::Ciphertext) +where + E::Ciphertext: Clone, +{ + assert!(a.len() == b.len()); + let n = a.len(); + + let mut carry = if !carry_in { + half_adder(evaluator, &mut a[0], &b[0], key) + } else { + full_adder_plain_carry_in(evaluator, &mut a[0], &b[0], true, key) + }; + + izip!(a.iter_mut(), b.iter()) + .skip(1) + .take(n - 3) + .for_each(|(a_bit, b_bit)| { + carry = full_adder(evaluator, a_bit, b_bit, &carry, key); + }); + + let carry_last_last = full_adder(evaluator, &mut a[n - 2], &b[n - 2], &carry, key); + let carry_last = full_adder(evaluator, &mut a[n - 1], &b[n - 1], &carry_last_last, key); + + (carry_last, carry_last_last) +} + +pub(super) fn arbitrary_bit_subtractor( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> (Vec, E::Ciphertext, E::Ciphertext) +where + E::Ciphertext: Clone, +{ + let mut neg_b: Vec = b.iter().map(|v| evaluator.not(v)).collect(); + let (carry_last, carry_last_last) = arbitrary_bit_adder(evaluator, &mut neg_b, &a, true, key); + return (neg_b, carry_last, carry_last_last); +} + +pub(super) fn bit_mux( + evaluator: &mut E, + selector: E::Ciphertext, + if_true: &E::Ciphertext, + if_false: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + // (s&a) | ((1-s)^b) + let not_selector = evaluator.not(&selector); + + let mut s_and_a = evaluator.and(&selector, if_true, key); + let s_and_b = evaluator.and(¬_selector, if_false, key); + evaluator.or(&mut s_and_a, &s_and_b, key); + s_and_a +} + +pub(super) fn arbitrary_bit_mux( + evaluator: &mut E, + selector: &E::Ciphertext, + if_true: &[E::Ciphertext], + if_false: &[E::Ciphertext], + key: &E::Key, +) -> Vec { + // (s&a) | ((1-s)^b) + let not_selector = evaluator.not(&selector); + + izip!(if_true.iter(), if_false.iter()) + .map(|(a, b)| { + let mut s_and_a = evaluator.and(&selector, a, key); + let s_and_b = evaluator.and(¬_selector, b, key); + evaluator.or_inplace(&mut s_and_a, &s_and_b, key); + s_and_a + }) + .collect() +} + +pub(super) fn eight_bit_mul( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> Vec { + assert!(a.len() == 8); + assert!(b.len() == 8); + let mut carries = Vec::with_capacity(7); + let mut out = Vec::with_capacity(8); + + for i in (0..8) { + if i == 0 { + let s = evaluator.and(&a[0], &b[0], key); + out.push(s); + } else if i == 1 { + let mut tmp0 = evaluator.and(&a[1], &b[0], key); + let tmp1 = evaluator.and(&a[0], &b[1], key); + let carry = half_adder(evaluator, &mut tmp0, &tmp1, key); + carries.push(carry); + out.push(tmp0); + } else { + let mut sum = { + let mut sum = evaluator.and(&a[i], &b[0], key); + let tmp = evaluator.and(&a[i - 1], &b[1], key); + carries[0] = full_adder(evaluator, &mut sum, &tmp, &carries[0], key); + sum + }; + + for j in 2..i { + let tmp = evaluator.and(&a[i - j], &b[j], key); + carries[j - 1] = full_adder(evaluator, &mut sum, &tmp, &carries[j - 1], key); + } + + let tmp = evaluator.and(&a[0], &b[i], key); + let carry = half_adder(evaluator, &mut sum, &tmp, key); + carries.push(carry); + + out.push(sum) + } + debug_assert!(carries.len() <= 7); + } + + out +} + +pub(super) fn arbitrary_bit_division_for_quotient_and_rem( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> (Vec, Vec) +where + E::Ciphertext: Clone, +{ + let n = a.len(); + let neg_b = b.iter().map(|v| evaluator.not(v)).collect_vec(); + + // Both remainder and quotient are initially stored in Big-endian in contrast to + // the usual little endian we use. This is more friendly to vec pushes in + // division. After computing remainder and quotient, we simply reverse the + // vectors. + let mut remainder = vec![]; + let mut quotient = vec![]; + for i in 0..n { + // left shift + remainder.push(a[n - 1 - i].clone()); + + let mut subtract = remainder.clone(); + + // subtraction + // At i^th iteration remainder is only filled with i bits and the rest of the + // bits are zero. For example, at i = 1 + // 0 0 0 0 0 0 X X => remainder + // - Y Y Y Y Y Y Y Y => divisor . + // --------------- . + // Z Z Z Z Z Z Z Z => result + // For the next iteration we only care about result if divisor is <= remainder + // (which implies result <= remainder). Otherwise we care about remainder + // (recall re-storing division). Hence we optimise subtraction and + // ignore full adders for places where remainder bits are known to be false + // bits. We instead use `ANDs` to compute the carry overs, since the + // last carry over indicates whether the value has overflown (i.e. divisor <= + // remainder). Last carry out is `true` if value has not overflown, otherwise + // false. + let mut carry = + full_adder_plain_carry_in(evaluator, &mut subtract[i], &neg_b[0], true, key); + for j in 1..i + 1 { + carry = full_adder(evaluator, &mut subtract[i - j], &neg_b[j], &carry, key); + } + for j in i + 1..n { + // All I care about are the carries + evaluator.and_inplace(&mut carry, &neg_b[j], key); + } + + let not_carry = evaluator.not(&carry); + // Choose `remainder` if subtraction has overflown (i.e. carry = false). + // Otherwise choose `subtractor`. + // + // mux k^a | !(k)^b, where k is the selector. + izip!(remainder.iter_mut(), subtract.iter_mut()).for_each(|(r, s)| { + // choose `s` when carry is true, otherwise choose r + evaluator.and_inplace(s, &carry, key); + evaluator.and_inplace(r, ¬_carry, key); + evaluator.or_inplace(r, s, key); + }); + + // Set i^th MSB of quotient to 1 if carry = true, otherwise set it to 0. + // X&1 | X&0 => X&1 => X + quotient.push(carry); + } + + remainder.reverse(); + quotient.reverse(); + + (quotient, remainder) +} + +pub(super) fn is_zero( + evaluator: &mut E, + a: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + let mut a = a.iter().map(|v| evaluator.not(v)).collect_vec(); + let (out, rest_a) = a.split_at_mut(1); + rest_a.iter().for_each(|c| { + evaluator.and_inplace(&mut out[0], c, key); + }); + return a.remove(0); +} + +pub(super) fn arbitrary_bit_equality( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let mut out = evaluator.xnor(&a[0], &b[0], key); + izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| { + let e = evaluator.xnor(abit, bbit, key); + evaluator.and_inplace(&mut out, &e, key); + }); + return out; +} + +/// Comparator handle computes comparator result 2ns MSB onwards. It is +/// separated because comparator subroutine for signed and unsgind integers +/// differs only for 1st MSB and is common second MSB onwards +fn _comparator_handler_from_second_msb( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + mut comp: E::Ciphertext, + mut casc: E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let n = a.len(); + + // handle MSB - 1 + let mut tmp = evaluator.not(&b[n - 2]); + evaluator.and_inplace(&mut tmp, &a[n - 2], key); + evaluator.and_inplace(&mut tmp, &casc, key); + evaluator.or_inplace(&mut comp, &tmp, key); + + for i in 2..n { + // calculate cascading bit + let tmp_casc = evaluator.xnor(&a[n - i], &b[n - i], key); + evaluator.and_inplace(&mut casc, &tmp_casc, key); + + // calculate computate bit + let mut tmp = evaluator.not(&b[n - 1 - i]); + evaluator.and_inplace(&mut tmp, &a[n - 1 - i], key); + evaluator.and_inplace(&mut tmp, &casc, key); + evaluator.or_inplace(&mut comp, &tmp, key); + } + + return comp; +} + +/// Signed integer comparison is same as unsigned integer with MSB flipped. +pub(super) fn arbitrary_signed_bit_comparator( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let n = a.len(); + + // handle MSB + let mut comp = evaluator.not(&a[n - 1]); + evaluator.and_inplace(&mut comp, &b[n - 1], key); // comp + let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc + + return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); +} + +pub(super) fn arbitrary_bit_comparator( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let n = a.len(); + + // handle MSB + let mut comp = evaluator.not(&b[n - 1]); + evaluator.and_inplace(&mut comp, &a[n - 1], key); + let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); + + return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); +} diff --git a/src/utils.rs b/src/utils.rs index 40d007a..3cac804 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,9 +1,14 @@ -use std::usize; +use std::{usize, vec}; -use itertools::Itertools; -use num_traits::{PrimInt, Signed}; +use itertools::{izip, Itertools}; +use num_traits::{One, PrimInt, Signed}; -use crate::RandomUniformDist; +use crate::{ + backend::Modulus, + decomposer::NumInfo, + random::{RandomElementInModulus, RandomFill}, + Matrix, RowEntity, RowMut, +}; pub trait WithLocal { fn with_local(func: F) -> R where @@ -12,26 +17,78 @@ pub trait WithLocal { fn with_local_mut(func: F) -> R where F: Fn(&mut Self) -> R; + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R; +} + +pub trait Global { + fn global() -> &'static Self; +} + +pub(crate) trait ShoupMul { + fn representation(value: Self, q: Self) -> Self; + fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self; +} + +impl ShoupMul for u64 { + #[inline] + fn representation(value: Self, q: Self) -> Self { + ((value as u128 * (1u128 << 64)) / q as u128) as u64 + } + + #[inline] + /// Returns a * b % q + fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self { + (b.wrapping_mul(a)) + .wrapping_sub(q.wrapping_mul(((b_shoup as u128 * a as u128) >> 64) as u64)) + } +} + +pub(crate) trait ToShoup { + type Modulus; + fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self; +} + +impl ToShoup for u64 { + type Modulus = u64; + fn to_shoup(value: &Self, modulus: Self) -> Self { + ((*value as u128 * (1u128 << 64)) / modulus as u128) as u64 + } +} + +impl ToShoup for Vec> { + type Modulus = u64; + fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self { + let (row, col) = value.dimension(); + let mut shoup_value = vec![vec![0u64; col]; row]; + izip!(shoup_value.iter_mut(), value.iter()).for_each(|(shoup_r, r)| { + izip!(shoup_r.iter_mut(), r.iter()).for_each(|(s, e)| { + *s = u64::to_shoup(e, modulus); + }) + }); + shoup_value + } } pub fn fill_random_ternary_secret_with_hamming_weight< T: Signed, - R: RandomUniformDist<[u8], Parameters = u8> + RandomUniformDist, + R: RandomFill<[u8]> + RandomElementInModulus, >( out: &mut [T], hamming_weight: usize, rng: &mut R, ) { let mut bytes = vec![0u8; hamming_weight.div_ceil(8)]; - RandomUniformDist::<[u8]>::random_fill(rng, &0, &mut bytes); + RandomFill::<[u8]>::random_fill(rng, &mut bytes); let size = out.len(); let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec(); let mut bit_index = 0; let mut byte_index = 0; - for _ in 0..hamming_weight { - let mut s_index = 0usize; - RandomUniformDist::::random_fill(rng, &secret_indices.len(), &mut s_index); + for i in 0..hamming_weight { + let s_index = RandomElementInModulus::::random(rng, &secret_indices.len()); let curr_bit = (bytes[byte_index] >> bit_index) & 1; if curr_bit == 1 { @@ -41,7 +98,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight< } secret_indices[s_index] = *secret_indices.last().unwrap(); - secret_indices.truncate(secret_indices.len()); + secret_indices.truncate(secret_indices.len() - 1); if bit_index == 7 { bit_index = 0; @@ -62,7 +119,7 @@ fn is_probably_prime(candidate: u64) -> bool { /// - $prime \lt upper_bound$ /// - $\log{prime} = num_bits$ /// - `prime % modulo == 1` -pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option { +pub(crate) fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option { let leading_zeros = (64 - num_bits) as u32; let mut tentative_prime = upper_bound - 1; @@ -107,15 +164,11 @@ pub fn mod_exponent(a: u64, mut b: u64, q: u64) -> u64 { out } -pub fn mod_inverse(a: u64, q: u64) -> u64 { +pub(crate) fn mod_inverse(a: u64, q: u64) -> u64 { mod_exponent(a, q - 2, q) } -pub fn shoup_representation_fq(v: u64, q: u64) -> u64 { - ((v as u128 * (1u128 << 64)) / q as u128) as u64 -} - -pub fn negacyclic_mul T>( +pub(crate) fn negacyclic_mul T>( a: &[T], b: &[T], mul: F, @@ -138,54 +191,167 @@ pub fn negacyclic_mul T>( return r; } -pub trait TryConvertFrom { - type Parameters: ?Sized; +/// Returns a polynomial X^{emebedding_factor * si} \mod {Z_Q / X^{N}+1} +pub(crate) fn encode_x_pow_si_with_emebedding_factor< + R: RowEntity + RowMut, + M: Modulus, +>( + si: i32, + embedding_factor: usize, + ring_size: usize, + modulus: &M, +) -> R +where + R::Element: One, +{ + assert!((si.abs() as usize) < ring_size); + let mut m = R::zeros(ring_size); + let si = si * (embedding_factor as i32); + if si < 0 { + // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N + m.as_mut()[ring_size - (si.abs() as usize)] = modulus.neg_one(); + } else { + m.as_mut()[si as usize] = R::Element::one(); + } + m +} - fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self; +pub(crate) fn puncture_p_rng>( + p_rng: &mut R, + times: usize, +) -> S { + let mut out = S::default(); + for _ in 0..times { + RandomFill::::random_fill(p_rng, &mut out); + } + return out; } -impl TryConvertFrom<[i32]> for Vec> { - type Parameters = u32; - fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self { - let row0 = value - .iter() - .map(|v| { - let is_neg = v.is_negative(); - let v_u32 = v.abs() as u32; +pub(crate) fn log2(v: &T) -> usize { + if (*v & (*v - T::one())) == T::zero() { + // value is power of 2 + (T::BITS - v.leading_zeros() - 1) as usize + } else { + (T::BITS - v.leading_zeros()) as usize + } +} - assert!(v_u32 < *parameters); +pub trait TryConvertFrom1 { + fn try_convert_from(value: &T, parameters: &P) -> Self; +} - if is_neg { - parameters - v_u32 - } else { - v_u32 - } - }) - .collect_vec(); +impl> TryConvertFrom1<[i64], P> for Vec { + fn try_convert_from(value: &[i64], parameters: &P) -> Self { + value + .iter() + .map(|v| parameters.map_element_from_i64(*v)) + .collect_vec() + } +} - vec![row0] +impl> TryConvertFrom1<[i32], P> for Vec { + fn try_convert_from(value: &[i32], parameters: &P) -> Self { + value + .iter() + .map(|v| parameters.map_element_from_i64(*v as i64)) + .collect_vec() } } -impl TryConvertFrom<[i32]> for Vec> { - type Parameters = u64; - fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self { - let row0 = value +impl TryConvertFrom1<[P::Element], P> for Vec { + fn try_convert_from(value: &[P::Element], parameters: &P) -> Self { + value .iter() - .map(|v| { - let is_neg = v.is_negative(); - let v_u64 = v.abs() as u64; + .map(|v| parameters.map_element_to_i64(v)) + .collect_vec() + } +} + +#[cfg(test)] +pub(crate) mod tests { + use std::fmt::Debug; + + use num_traits::ToPrimitive; + + use crate::random::DefaultSecureRng; - assert!(v_u64 < *parameters); + use super::fill_random_ternary_secret_with_hamming_weight; - if is_neg { - parameters - v_u64 - } else { - v_u64 + #[derive(Clone)] + pub(crate) struct Stats { + pub(crate) samples: Vec, + } + + impl Default for Stats { + fn default() -> Self { + Stats { samples: vec![] } + } + } + + impl Stats + where + // T: for<'a> Sum<&'a T>, + T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, + { + pub(crate) fn new() -> Self { + Self { samples: vec![] } + } + + pub(crate) fn mean(&self) -> f64 { + self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) + } + + pub(crate) fn variance(&self) -> f64 { + let mean = self.mean(); + + // diff + let diff_sq = self + .samples + .iter() + .map(|v| { + let t = v.to_f64().unwrap() - mean; + t * t + }) + .into_iter() + .sum::(); + + diff_sq / (self.samples.len() as f64 - 1.0) + } + + pub(crate) fn std_dev(&self) -> f64 { + self.variance().sqrt() + } + + pub(crate) fn add_many_samples(&mut self, values: &[T]) { + self.samples.extend(values.iter()); + } + + pub(crate) fn add_sample(&mut self, value: T) { + self.samples.push(value) + } + + pub(crate) fn merge_in(&mut self, other: &Self) { + self.samples.extend(other.samples.iter()); + } + } + + #[test] + fn ternary_secret_has_correct_hw() { + let mut rng = DefaultSecureRng::new(); + for n in 4..15 { + let ring_size = 1 << n; + let mut out = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(&mut out, ring_size >> 1, &mut rng); + + // check hamming weight of out equals ring_size/2 + let mut non_zeros = 0; + out.iter().for_each(|i| { + if *i != 0 { + non_zeros += 1; } - }) - .collect_vec(); + }); - vec![row0] + assert_eq!(ring_size >> 1, non_zeros); + } } }