Browse Source

Add version 0.1.0 (#2)

Add version 0.1.0. 

Includes:

- FHE Uint8, FHE bool, div by 0 error flag, and mux APIs
- non-interactive mpc for <= 8 parties
- interactive mpc for <= 8 parties
- multi-party decryption
- necessary examples
main
Janmajayamall 9 months ago
committed by GitHub
parent
commit
a8e6c27627
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
41 changed files with 15362 additions and 1156 deletions
  1. +1
    -0
      .gitignore
  2. +507
    -10
      Cargo.lock
  3. +50
    -1
      Cargo.toml
  4. +86
    -0
      README.md
  5. +152
    -0
      benches/modulus.rs
  6. +151
    -0
      benches/ntt.rs
  7. +178
    -0
      examples/bomberman.rs
  8. +126
    -0
      examples/div_by_zero.rs
  9. +107
    -0
      examples/if_and_else.rs
  10. +180
    -0
      examples/interactive_fheuint8.rs
  11. +150
    -0
      examples/meeting_friends.rs
  12. +177
    -0
      examples/non_interactive_fheuint8.rs
  13. +0
    -163
      src/backend.rs
  14. +141
    -0
      src/backend/mod.rs
  15. +337
    -0
      src/backend/modulus_u64.rs
  16. +112
    -0
      src/backend/power_of_2.rs
  17. +124
    -0
      src/backend/word_size.rs
  18. +2323
    -0
      src/bool/evaluator.rs
  19. +1559
    -0
      src/bool/keys.rs
  20. +266
    -0
      src/bool/mod.rs
  21. +697
    -0
      src/bool/mp_api.rs
  22. +459
    -0
      src/bool/ni_mp_api.rs
  23. +738
    -0
      src/bool/parameters.rs
  24. +1020
    -0
      src/bool/print_noise.rs
  25. +295
    -88
      src/decomposer.rs
  26. +121
    -11
      src/lib.rs
  27. +221
    -174
      src/lwe.rs
  28. +1
    -3
      src/main.rs
  29. +286
    -0
      src/multi_party.rs
  30. +229
    -96
      src/ntt.rs
  31. +0
    -3
      src/num.rs
  32. +482
    -0
      src/pbs.rs
  33. +126
    -89
      src/random.rs
  34. +0
    -466
      src/rgsw.rs
  35. +677
    -0
      src/rgsw/keygen.rs
  36. +982
    -0
      src/rgsw/mod.rs
  37. +1063
    -0
      src/rgsw/runtime.rs
  38. +370
    -0
      src/shortint/enc_dec.rs
  39. +294
    -0
      src/shortint/mod.rs
  40. +356
    -0
      src/shortint/ops.rs
  41. +218
    -52
      src/utils.rs

+ 1
- 0
.gitignore

@ -1 +1,2 @@
/target /target
/.obsidian

+ 507
- 10
Cargo.lock

@ -2,6 +2,27 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstyle"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.2.0" version = "1.2.0"
@ -9,16 +30,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]] [[package]]
name = "bin-rs"
version = "0.1.0"
dependencies = [
"itertools",
"num-bigint-dig",
"num-traits",
"rand",
"rand_chacha",
"rand_distr",
]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@ -26,12 +41,137 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.0" version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "clap"
version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0"
dependencies = [
"clap_builder",
]
[[package]]
name = "clap_builder"
version = "4.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4"
dependencies = [
"anstyle",
"clap_lex",
]
[[package]]
name = "clap_lex"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]] [[package]]
name = "either" name = "either"
version = "1.11.0" version = "1.11.0"
@ -49,6 +189,42 @@ dependencies = [
"wasi", "wasi",
] ]
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "is-terminal"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
dependencies = [
"hermit-abi",
"libc",
"windows-sys",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.12.1" version = "0.12.1"
@ -58,6 +234,21 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "js-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
dependencies = [
"wasm-bindgen",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -79,6 +270,18 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "log"
version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "memchr"
version = "2.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
[[package]] [[package]]
name = "num-bigint-dig" name = "num-bigint-dig"
version = "0.8.4" version = "0.8.4"
@ -126,6 +329,59 @@ dependencies = [
"libm", "libm",
] ]
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "oorandom"
version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "phantom-zone"
version = "0.1.0"
dependencies = [
"criterion",
"itertools 0.12.1",
"num-bigint-dig",
"num-traits",
"rand",
"rand_chacha",
"rand_distr",
]
[[package]]
name = "plotters"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7"
[[package]]
name = "plotters-svg"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.17" version = "0.2.17"
@ -190,6 +446,70 @@ dependencies = [
"rand", "rand",
] ]
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "regex"
version = "1.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.198" version = "1.0.198"
@ -210,6 +530,17 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "serde_json"
version = "1.0.117"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.13.2" version = "1.13.2"
@ -233,14 +564,180 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.12" version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
[[package]]
name = "web-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"

+ 50
- 1
Cargo.toml

@ -1,7 +1,12 @@
[package] [package]
name = "bin-rs"
name = "phantom-zone"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
readme = "README.md"
repository = "https://github.com/gausslabs/phantom-zone"
license = "MIT"
keywords = ["fhe", "mpc", "cryptography"]
description = "Library for multi-party computation using fully-homomorphic encryption"
[dependencies] [dependencies]
itertools = "0.12.0" itertools = "0.12.0"
@ -10,3 +15,47 @@ rand = "0.8.5"
rand_chacha = "0.3.1" rand_chacha = "0.3.1"
rand_distr = "0.4.3" rand_distr = "0.4.3"
num-bigint-dig = { version = "0.8.4", features = ["prime"] } num-bigint-dig = { version = "0.8.4", features = ["prime"] }
[dev-dependencies]
criterion = "0.5.1"
[features]
interactive_mp = []
non_interactive_mp = []
[[bench]]
name = "ntt"
harness = false
[[bench]]
name = "modulus"
harness = false
[[example]]
name = "interactive_fheuint8"
path = "./examples/interactive_fheuint8.rs"
[[example]]
name = "non_interactive_fheuint8"
path = "./examples/non_interactive_fheuint8.rs"
required-features = ["non_interactive_mp"]
[[example]]
name = "meeting_friends"
path = "./examples/meeting_friends.rs"
required-features = ["non_interactive_mp"]
[[example]]
name = "bomberman"
path = "./examples/bomberman.rs"
required-features = ["non_interactive_mp"]
[[example]]
name = "div_by_zero"
path = "./examples/div_by_zero.rs"
required-features = ["non_interactive_mp"]
[[example]]
name = "if_and_else"
path = "./examples/if_and_else.rs"
required-features = ["non_interactive_mp"]

+ 86
- 0
README.md

@ -0,0 +1,86 @@
**Phantom zone** is similar to the zone where superman gets locked and observes everything outside the zone, but no one outside can see or hear superman. However our phantom zone isn't meant to lock anyone. Instead it's meant to be a new zone in parallel to reality. It's the zone to which you teleport yourself with others, take arbitrary actions, and remember only predefined set of memories when you're back. Think of the zone as a computer that erases itself off of the face of the earth after it returns the output, leaving no trace behind.
**More formally, phantom-zone is a experimental multi-party computation library that uses multi-party fully homomorphic encryption to compute arbitrary functions on private inputs from multiple parties.**
At the moment phantom-zone is pretty limited in its functionality. It offers to write circuits with encrypted 8 bit unsigned integers (referred to as FheUint8) and only supports upto 8 parties. FheUint8 supports the same arithmetic as a regular uint8, with a few exceptions mentioned below. We plan to extend APIs to other signed/unsigned types in the future.
We provide two types of multi-party protocols, both only differ in key-generation procedure.
1. **Non-interactive multi-party protocol,** which requires a single shot message from the clients to the server after which the server can evaluate any arbitrary function on encrypted client inputs.
2. **Interactive multi-party protocol**, a 2 round protocol where in the first round clients interact to generate collective public key and in the second round clients send their server key share to the server, after which server can evaluate any arbitrary function on encrypted client inputs.
Understanding that library is in experimental stage, if you want to use it for your application but find it lacking in features, please don't shy away from opening a issue or getting in touch. We don't want you to hold back those imaginative horses!
## How to use
We provide a brief overview below. Please refer to detailed examples, especially [non_interactive_fheuint8](./examples/non_interactive_fheuint8.rs) and [interactive_fheuint8](./examples/interactive_fheuint8.rs), to understand how to instantiate and run the protocols.
### Non-interactive multi-party
Each client is assigned an `id`, referred to as `user_id`, which denotes serial no. of the client out of total clients participating in the multi-party protocol. After learning their `user_id`, the client uploads their server key share along with encryptions of private inputs in a single shot message to the server. Server can then evaluate any arbitrary function on clients' private inputs. New private inputs can be provided in the future by the fix set of parties that participated in the protocol.
### Interactive multi-party
Like the non-interactive multi-party, each client is assigned `user_id`. After learning their `id`, clients participate in a 2 round protocol. In round 1, clients generate public key shares, share it with each other, and aggregate public key shares to produce the collective public key. In round 2, clients use the collective public key to generate their server key shares and encrypt their private inputs. Server receives server key shares and encryptions of private inputs from each client. Server aggregates the server key shares, after which it can evaluate any arbitrary function on clients' private inputs. New private inputs can be provided in the future by anyone with access to collective public key.
### Multi-party decryption
To decrypt output ciphertext(s) obtained as result of some computation, the clients come online. They download output ciphertext(s) from the server, generate decryption shares, and share it with other parties. Clients, after receiving decryption shares of other parties, aggregate the shares and decrypt the ciphertext(s).
### Parameter selection
We provide parameters to run both multi-party protocols for upto 8 parties.
| $\leq$ # Parties | Interactive multi-party | Non-interactive multi-party |
| ------------ | ----------------------- | --------------------------- |
| 2 | InteractiveLTE2Party | NonInteractiveLTE2Party |
| 4 | InteractiveLTE4Party | NonInteractiveLTE4Party |
| 8 | InteractiveLTE8Party | NonInteractiveLTE8Party |
If you have use-case `> 8` parties, please open an issue. We're willing to find suitable parameters for `> 8` parties.
Parameters supporting `<= N` parties must not be used for multi-party compute between `> N` parties. This will lead to increase in failure probability.
### Feature selection
To use the library for non-interactive multi-party, you must add `non_interactive_mp` feature flag like `--features "non_interactive_mp"`. And to use the library for interactive multi-party you must add `interactive_mp` feature flag like `--features "interactive_mp"`.
### FheUInt8
We provide APIs for all basic arithmetic (+, -, x, /, %) and comparison operations.
All arithmetic operation by default wrap around (i.e. $\mod{256}$ for FheUint8). We also provide `overflow_{add/add_assign}` and `overflow_sub` that returns a flag ciphertext which is set to `True` if addition/subtraction overflowed and to `False` otherwise.
Division operation (/) returns `quotient` and the remainder operation (%) returns `remainder` s.t. `dividend = division x quotient + remainder`. If both `quotient` and `remainder` are required, then `div_rem` can be used. In case of division by zero, [Div by zero error flag](#Div-by-zero-error-flag) will be set and `quotient` will be set to `255` and `remainder` to equal `dividend`.
**Div by zero error flag**
In encrypted domain there's no way to panic upon division by zero at runtime. Instead we set a local flag ciphertext accessible via `div_by_zero_flag()` that stores a boolean ciphertext indicating whether any of the divisions performed during the execution attempted division by zero. Assuming division by zero detection is critical for your application, we recommend decrypting the flag ciphertext along with other output ciphertexts in multi-party decryption procedure.
The div by zero flag is thread local. If you run multiple different FHE circuits in sequence without stopping the thread (i.e. within a single program) you will have to reset div by zero error flag before starting of the next in-sequence circuit execution with `reset_error_flags()`.
Please refer to [div_by_zero](./examples/div_by_zero.rs) example for more details.
**If and else using mux**
Branching in encrypted domain is expensive because the code must execute all the branches. Hence cost grows exponentially with no. of conditional branches. In general we recommend to modify the code to minimise conditional branches. However, if a code cannot be modified to made branchless, we provide `mux` API for FheUint8s. `mux` selects one of the two FheUint8s based on a selector bit. Please refer to [if_and_else](./examples/if_and_else.rs) example for more details.
## Security
> [!WARNING]
> Code has not been audited and we currently do not provide any security guarantees outside of the cryptographic parameters. We don't recommend to deploy it in production or use to handle sensitive data.
All provided parameters are $2^{128}$ ring operations secure according to [lattice estimator](https://github.com/malb/lattice-estimator) and have failure probability of $\leq 2^{-40}$. However, there are two vital points to keep in mind:
1. Users must not generate two different decryption shares for the same ciphertext, as it can lead to key-recovery attacks. To avoid this, we suggest users maintain a local table listing ciphertext against any previously generated decryption share. Then only generate a new decryption share if ciphertext does not exist in the table, otherwise return the existing share. We believe this should be handled by the library and will add support for this in future.
2. Users must not run the MPC protocol more than once for the same application seed and produce different outputs, as it can lead to key-recovery attacks. We believe this should be handled by the library and will add support for this in future.
## Credits
- We thank Barry Whitehat and Brian Lawrence for many helpful discussions.
- We thank Vivek Bhupatiraju and Andrew Lu for for many insightful discussions on fascinating phantom zone applications.
- Non-interactive multi-party RLWE key generation setup is a new protocol designed by Jean Philippe and Janmajaya. We thank Christian Mouchet for the review of the protocol and helpful suggestions.
- We thank Yao Wang for his help with rust compiler troubles.
## References
1. [Efficient FHEW Bootstrapping with Small Evaluation Keys, and Applications to Threshold Homomorphic Encryption](https://eprint.iacr.org/2022/198.pdf)
2. [Multiparty Homomorphic Encryption from Ring-Learning-with-Errors](https://eprint.iacr.org/2020/304.pdf)

+ 152
- 0
benches/modulus.rs

@ -0,0 +1,152 @@
use bin_rs::{
ArithmeticLazyOps, ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64,
ShoupMatrixFMA, VectorOps,
};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
let ring_size = r.len();
// let d = decomposer.decomposition_count();
// let mut count = 0;
for ri in 0..ring_size {
// let el_decomposed = decomposer.decompose(&r[ri]);
decomposer
.decompose_iter(&r[ri])
.enumerate()
.into_iter()
.for_each(|(j, el)| {
decomp_r[j][ri] = el;
});
}
}
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
izip!(a.iter(), b.iter()).for_each(|(a_r, b_r)| {
izip!(out.iter_mut(), a_r.iter(), b_r.iter())
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul_lazy(ai, bi)));
});
}
fn benchmark_decomposer(c: &mut Criterion) {
let mut group = c.benchmark_group("decomposer");
// let decomposers = vec![];
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11] {
let logb = 11;
let decomposer = DefaultDecomposer::new(prime, logb, 2);
let mut rng = thread_rng();
let dist = Uniform::new(0, prime);
let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
group.bench_function(
BenchmarkId::new(
"decompose",
format!(
"q={prime}/N={ring_size}/logB={logb}/d={}",
*decomposer.decomposition_count().as_ref()
),
),
|b| {
b.iter_batched_ref(
|| {
(
a.clone(),
vec![
vec![0u64; ring_size];
*decomposer.decomposition_count().as_ref()
],
)
},
|(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)),
criterion::BatchSize::PerIteration,
)
},
);
}
}
group.finish();
}
fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("modulus");
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11] {
let modop = ModularOpsU64::new(prime);
let mut rng = thread_rng();
let dist = Uniform::new(0, prime);
let a0 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let d = 1;
let a0_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
// a0 in shoup representation
let a0_shoup_matrix = a0_matrix
.iter()
.map(|r| {
r.iter()
.map(|v| {
// $(v * 2^{\beta}) / p$
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
})
.collect_vec()
})
.collect_vec();
let a1_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
group.bench_function(
BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
|b| {
b.iter_batched_ref(
|| (vec![0u64; ring_size]),
|(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new(
"matrix_shoup_fma_lazy",
format!("q={prime}/N={ring_size}/d={d}"),
),
|b| {
b.iter_batched_ref(
|| (vec![0u64; ring_size]),
|(out)| {
black_box(modop.shoup_matrix_fma(
out,
&a0_matrix,
&a0_shoup_matrix,
&a1_matrix,
))
},
criterion::BatchSize::PerIteration,
)
},
);
}
}
group.finish();
}
criterion_group!(decomposer, benchmark_decomposer);
criterion_group!(modulus, benchmark);
criterion_main!(modulus, decomposer);

+ 151
- 0
benches/ntt.rs

@ -0,0 +1,151 @@
use bin_rs::{Ntt, NttBackendU64, NttInit};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::Itertools;
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
fn forward_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
a.iter_mut().for_each(|r| nttop.forward(r.as_mut_slice()));
}
fn forward_lazy_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
a.iter_mut()
.for_each(|r| nttop.forward_lazy(r.as_mut_slice()));
}
fn backward_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
a.iter_mut().for_each(|r| nttop.backward(r.as_mut_slice()));
}
fn backward_lazy_matrix(a: &mut [Vec<u64>], nttop: &NttBackendU64) {
a.iter_mut()
.for_each(|r| nttop.backward_lazy(r.as_mut_slice()));
}
fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("ntt");
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11] {
let ntt = NttBackendU64::new(&prime, ring_size);
let mut rng = thread_rng();
let a = (&mut rng)
.sample_iter(Uniform::new(0, prime))
.take(ring_size)
.collect_vec();
let d = 2;
let a_matrix = (0..d)
.map(|_| {
(&mut rng)
.sample_iter(Uniform::new(0, prime))
.take(ring_size)
.collect_vec()
})
.collect_vec();
{
group.bench_function(
BenchmarkId::new("forward", format!("q={prime}/N={ring_size}")),
|b| {
b.iter_batched_ref(
|| a.clone(),
|mut a| black_box(ntt.forward(&mut a)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new("forward_lazy", format!("q={prime}/N={ring_size}")),
|b| {
b.iter_batched_ref(
|| a.clone(),
|mut a| black_box(ntt.forward_lazy(&mut a)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new("forward_matrix", format!("q={prime}/N={ring_size}/d={d}")),
|b| {
b.iter_batched_ref(
|| a_matrix.clone(),
|a_matrix| black_box(forward_matrix(a_matrix, &ntt)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new(
"forward_lazy_matrix",
format!("q={prime}/N={ring_size}/d={d}"),
),
|b| {
b.iter_batched_ref(
|| a_matrix.clone(),
|a_matrix| black_box(forward_lazy_matrix(a_matrix, &ntt)),
criterion::BatchSize::PerIteration,
)
},
);
}
{
group.bench_function(
BenchmarkId::new("backward", format!("q={prime}/N={ring_size}")),
|b| {
b.iter_batched_ref(
|| a.clone(),
|mut a| black_box(ntt.backward(&mut a)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new("backward_lazy", format!("q={prime}/N={ring_size}")),
|b| {
b.iter_batched_ref(
|| a.clone(),
|mut a| black_box(ntt.backward_lazy(&mut a)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new("backward_matrix", format!("q={prime}/N={ring_size}")),
|b| {
b.iter_batched_ref(
|| a_matrix.clone(),
|a_matrix| black_box(backward_matrix(a_matrix, &ntt)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new(
"backward_lazy_matrix",
format!("q={prime}/N={ring_size}/d={d}"),
),
|b| {
b.iter_batched_ref(
|| a_matrix.clone(),
|a_matrix| black_box(backward_lazy_matrix(a_matrix, &ntt)),
criterion::BatchSize::PerIteration,
)
},
);
}
}
}
group.finish();
}
criterion_group!(ntt, benchmark);
criterion_main!(ntt);

+ 178
- 0
examples/bomberman.rs

@ -0,0 +1,178 @@
use std::fmt::Debug;
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
struct Coordinates<T>(T, T);
impl<T> Coordinates<T> {
fn new(x: T, y: T) -> Self {
Coordinates(x, y)
}
fn x(&self) -> &T {
&self.0
}
fn y(&self) -> &T {
&self.1
}
}
impl<T> Debug for Coordinates<T>
where
T: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Coordinates")
.field("x", self.x())
.field("y", self.y())
.finish()
}
}
fn coordinates_is_equal(a: &Coordinates<FheUint8>, b: &Coordinates<FheUint8>) -> FheBool {
&(a.x().eq(b.x())) & &(a.y().eq(b.y()))
}
/// Traverse the map with `p0` moves and check whether any of the moves equal
/// bomb coordinates (in encrypted domain)
fn traverse_map(p0: &[Coordinates<FheUint8>], bomb_coords: &[Coordinates<FheUint8>]) -> FheBool {
// First move
let mut out = coordinates_is_equal(&p0[0], &bomb_coords[0]);
bomb_coords.iter().skip(1).for_each(|b_coord| {
out |= coordinates_is_equal(&p0[0], &b_coord);
});
// rest of the moves
p0.iter().skip(1).for_each(|m_coord| {
bomb_coords.iter().for_each(|b_coord| {
out |= coordinates_is_equal(m_coord, b_coord);
});
});
out
}
// Do you recall bomberman? It's an interesting game where the bomberman has to
// cross the map without stepping on strategically placed bombs all over the
// map. Below we implement a very basic prototype of bomberman with 4 players.
//
// The map has 256 tiles with bottom left-most tile labelled with coordinates
// (0,0) and top right-most tile labelled with coordinates (255, 255). There are
// 4 players: Player 0, Player 1, Player 2, Player 3. Player 0's task is to walk
// across the map with fixed no. of moves while preventing itself from stepping
// on any of the bombs placed on the map by Player 1, 2, and 3.
//
// The twist is that Player 0's moves and the locations of bombs placed by other
// players are encrypted. Player 0 moves across the map in encrypted domain.
// Only a boolean output indicating whether player 0 survived after all the
// moves or killed itself by stepping onto a bomb is revealed at the end. If
// player 0 survives, Player 1, 2, 3 never learn what moves did Player 0 make.
// If Player 0 kills itself by stepping onto a bomb, it only learns that bomb
// was placed on one of the coordinates it moved to. Moreover, Player 1, 2, 3
// never learn locations of each other bombs or whose bomb killed Player 0.
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveLTE4Party);
// set application's common reference seed
let mut seed = [0; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 4;
// Client side //
// Players generate client keys
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
// Players generate server keys
let server_key_shares = cks
.iter()
.enumerate()
.map(|(index, k)| gen_server_key_share(index, no_of_parties, k))
.collect_vec();
// Player 0 describes its moves as sequence of coordinates on the map
let no_of_moves = 10;
let player_0_moves = (0..no_of_moves)
.map(|_| Coordinates::new(thread_rng().gen::<u8>(), thread_rng().gen()))
.collect_vec();
// Coordinates of bomb placed by Player 1
let player_1_bomb = Coordinates::new(thread_rng().gen::<u8>(), thread_rng().gen());
// Coordinates of bomb placed by Player 2
let player_2_bomb = Coordinates::new(thread_rng().gen::<u8>(), thread_rng().gen());
// Coordinates of bomb placed by Player 3
let player_3_bomb = Coordinates::new(thread_rng().gen::<u8>(), thread_rng().gen());
println!("P0 moves coordinates: {:?}", &player_0_moves);
println!("P1 bomb coordinates : {:?}", &player_1_bomb);
println!("P2 bomb coordinates : {:?}", &player_2_bomb);
println!("P3 bomb coordinates : {:?}", &player_3_bomb);
// Players encrypt their private inputs
let player_0_enc = cks[0].encrypt(
player_0_moves
.iter()
.flat_map(|c| vec![*c.x(), *c.y()])
.collect_vec()
.as_slice(),
);
let player_1_enc = cks[1].encrypt(vec![*player_1_bomb.x(), *player_1_bomb.y()].as_slice());
let player_2_enc = cks[2].encrypt(vec![*player_2_bomb.x(), *player_2_bomb.y()].as_slice());
let player_3_enc = cks[3].encrypt(vec![*player_3_bomb.x(), *player_3_bomb.y()].as_slice());
// Players upload the encrypted inputs and server key shares to the server
// Server side //
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// server parses Player inputs
let player_0_moves_enc = {
let c = player_0_enc
.unseed::<Vec<Vec<u64>>>()
.key_switch(0)
.extract_all();
c.into_iter()
.chunks(2)
.into_iter()
.map(|mut x_y| Coordinates::new(x_y.next().unwrap(), x_y.next().unwrap()))
.collect_vec()
};
let player_1_bomb_enc = {
let c = player_1_enc.unseed::<Vec<Vec<u64>>>().key_switch(1);
Coordinates::new(c.extract_at(0), c.extract_at(1))
};
let player_2_bomb_enc = {
let c = player_2_enc.unseed::<Vec<Vec<u64>>>().key_switch(2);
Coordinates::new(c.extract_at(0), c.extract_at(1))
};
let player_3_bomb_enc = {
let c = player_3_enc.unseed::<Vec<Vec<u64>>>().key_switch(3);
Coordinates::new(c.extract_at(0), c.extract_at(1))
};
// Server runs the game
let player_0_dead_ct = traverse_map(
&player_0_moves_enc,
&vec![player_1_bomb_enc, player_2_bomb_enc, player_3_bomb_enc],
);
// Client side //
// Players generate decryption shares and send them to each other
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&player_0_dead_ct))
.collect_vec();
// Players decrypt to find whether Player 0 survived
let player_0_dead = cks[0].aggregate_decryption_shares(&player_0_dead_ct, &decryption_shares);
if player_0_dead {
println!("Oops! Player 0 dead");
} else {
println!("Wohoo! Player 0 survived");
}
}

+ 126
- 0
examples/div_by_zero.rs

@ -0,0 +1,126 @@
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
// set application's common reference seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 2;
// Generate client keys
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
// Generate server key shares
let server_key_shares = cks
.iter()
.enumerate()
.map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
.collect_vec();
// Aggregate server key shares and set the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// --------
// We attempt to divide by 0 in encrypted domain and then check whether div by 0
// error flag is set to True.
let numerator = thread_rng().gen::<u8>();
let numerator_enc = cks[0]
.encrypt(vec![numerator].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(0)
.extract_at(0);
let zero_enc = cks[1]
.encrypt(vec![0].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(1)
.extract_at(0);
let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&zero_enc);
// When attempting to divide by zero, for uint8 quotient is always 255 and
// remainder = numerator
let quotient = cks[0].aggregate_decryption_shares(
&quotient_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&quotient_enc))
.collect_vec(),
);
let remainder = cks[0].aggregate_decryption_shares(
&remainder_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&remainder_enc))
.collect_vec(),
);
assert!(quotient == 255);
assert!(remainder == numerator);
// Div by zero error flag must be True
let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set");
let div_by_zero = cks[0].aggregate_decryption_shares(
&div_by_zero_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&div_by_zero_enc))
.collect_vec(),
);
assert!(div_by_zero == true);
// -------
// div by zero error flag is thread local. If we were to run another circuit
// without stopping the thread (i.e. within the same program as previous
// one), we must reset errors flags set by previous circuit with
// `reset_error_flags()` to prevent error flags of previous circuit affecting
// the flags of the next circuit.
reset_error_flags();
// We divide again but with non-zero denominator this time and check that div
// by zero flag is set to False
let numerator = thread_rng().gen::<u8>();
let mut denominator = thread_rng().gen::<u8>();
while denominator == 0 {
denominator = thread_rng().gen::<u8>();
}
let numerator_enc = cks[0]
.encrypt(vec![numerator].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(0)
.extract_at(0);
let denominator_enc = cks[1]
.encrypt(vec![denominator].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(1)
.extract_at(0);
let (quotient_enc, remainder_enc) = numerator_enc.div_rem(&denominator_enc);
let quotient = cks[0].aggregate_decryption_shares(
&quotient_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&quotient_enc))
.collect_vec(),
);
let remainder = cks[0].aggregate_decryption_shares(
&remainder_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&remainder_enc))
.collect_vec(),
);
assert!(quotient == numerator.div_euclid(denominator));
assert!(remainder == numerator.rem_euclid(denominator));
// Div by zero error flag must be set to False
let div_by_zero_enc = div_zero_error_flag().expect("We performed division. Flag must be set");
let div_by_zero = cks[0].aggregate_decryption_shares(
&div_by_zero_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&div_by_zero_enc))
.collect_vec(),
);
assert!(div_by_zero == false);
}

+ 107
- 0
examples/if_and_else.rs

@ -0,0 +1,107 @@
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
/// Code that runs when conditional branch is `True`
fn circuit_branch_true(a: &FheUint8, b: &FheUint8) -> FheUint8 {
a + b
}
/// Code that runs when conditional branch is `False`
fn circuit_branch_false(a: &FheUint8, b: &FheUint8) -> FheUint8 {
a * b
}
// Conditional branching (ie. If and else) are generally expensive in encrypted
// domain. The code must execute all the branches, and, as apparent, the
// runtime cost grows exponentially with no. of conditional branches.
//
// In general we recommend to write branchless code. In case the code cannot be
// modified to be branchless, the code must execute all branches and use a
// muxer to select correct output at the end.
//
// Below we showcase example of a single conditional branch in encrypted domain.
// The code executes both the branches (i.e. program runs both If and Else) and
// selects output of one of the branches with a mux.
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
// set application's common reference seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 2;
// Generate client keys
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
// Generate server key shares
let server_key_shares = cks
.iter()
.enumerate()
.map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
.collect_vec();
// Aggregate server key shares and set the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// -------
// User 0 encrypts their private input `v_a` and User 1 encrypts their
// private input `v_b`. We want to execute:
//
// if v_a < v_b:
// return v_a + v_b
// else:
// return v_a * v_b
//
// We define two functions
// (1) `circuit_branch_true`: which executes v_a + v_b in encrypted domain.
// (2) `circuit_branch_false`: which executes v_a * v_b in encrypted
// domain.
//
// The circuit runs both `circuit_branch_true` and `circuit_branch_false` and
// then selects the output of `circuit_branch_true` if `v_a < v_b == TRUE`
// otherwise selects the output of `circuit_branch_false` if `v_a < v_b ==
// FALSE` using mux.
// Clients private inputs
let v_a = thread_rng().gen::<u8>();
let v_b = thread_rng().gen::<u8>();
let v_a_enc = cks[0]
.encrypt(vec![v_a].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(0)
.extract_at(0);
let v_b_enc = cks[1]
.encrypt(vec![v_b].as_slice())
.unseed::<Vec<Vec<u64>>>()
.key_switch(1)
.extract_at(0);
// Run both branches
let out_true_enc = circuit_branch_true(&v_a_enc, &v_b_enc);
let out_false_enc = circuit_branch_false(&v_a_enc, &v_b_enc);
// define condition select v_a < v_b
let selector_bit = v_a_enc.lt(&v_b_enc);
// select output of `circuit_branch_true` if selector_bit == TRUE otherwise
// select output of `circuit_branch_false`
let out_enc = out_true_enc.mux(&out_false_enc, &selector_bit);
let out = cks[0].aggregate_decryption_shares(
&out_enc,
&cks.iter()
.map(|k| k.gen_decryption_share(&out_enc))
.collect_vec(),
);
let want_out = if v_a < v_b {
v_a.wrapping_add(v_b)
} else {
v_a.wrapping_mul(v_b)
};
assert_eq!(out, want_out);
}

+ 180
- 0
examples/interactive_fheuint8.rs

@ -0,0 +1,180 @@
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 {
((a + b) * c) * d
}
fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
&(&(a + b) * c) * d
}
fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 {
(a * b) + (c * d)
}
fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
&(a * b) + &(c * d)
}
fn main() {
// Select parameter set
set_parameter_set(ParameterSelector::InteractiveLTE4Party);
// set application's common reference seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 4;
// Client side //
// Clients generate their private keys
let cks = (0..no_of_parties)
.into_iter()
.map(|_| gen_client_key())
.collect_vec();
// -- Round 1 -- //
// In round 1 each client generates their share for the collective public key.
// They send public key shares to each other with or out without the server.
// After receiving others public key shares clients independently aggregate
// the shares and produce the collective public key `pk`
let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec();
// Clients aggregate public key shares to produce collective public key `pk`
let pk = aggregate_public_key_shares(&pk_shares);
// -- Round 2 -- //
// In round 2 each client generates server key share using the public key `pk`.
// Clients may also encrypt their private inputs using collective public key
// `pk`. Each client then uploads their server key share and private input
// ciphertexts to the server.
// Clients generate server key shares
//
// We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client
// 2, and user_id 4 to client 4.
//
// Note that `user_id`'s must be unique among the clients and must be less than
// total number of clients.
let server_key_shares = cks
.iter()
.enumerate()
.map(|(user_id, k)| collective_server_key_share(k, user_id, no_of_parties, &pk))
.collect_vec();
// Each client encrypts their private inputs using the collective public key
// `pk`. Unlike non-inteactive MPC protocol, private inputs are
// encrypted using collective public key.
let c0_a = thread_rng().gen::<u8>();
let c0_enc = pk.encrypt(vec![c0_a].as_slice());
let c1_a = thread_rng().gen::<u8>();
let c1_enc = pk.encrypt(vec![c1_a].as_slice());
let c2_a = thread_rng().gen::<u8>();
let c2_enc = pk.encrypt(vec![c2_a].as_slice());
let c3_a = thread_rng().gen::<u8>();
let c3_enc = pk.encrypt(vec![c3_a].as_slice());
// Clients upload their server key along with private encrypted inputs to
// the server
// Server side //
// Server receives server key shares from each client and proceeds to
// aggregate the shares and produce the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// Server proceeds to extract clients private inputs
//
// Clients encrypt their FheUint8s inputs packed in a batched ciphertext.
// The server must extract clients private inputs from the batch ciphertext
// either (1) using `extract_at(index)` to extract `index`^{th} FheUint8
// ciphertext (2) or using `extract_all()` to extract all available FheUint8s
// (3) or using `extract_many(many)` to extract first `many` available FheUint8s
let c0_a_enc = c0_enc.extract_at(0);
let c1_a_enc = c1_enc.extract_at(0);
let c2_a_enc = c2_enc.extract_at(0);
let c3_a_enc = c3_enc.extract_at(0);
// Server proceeds to evaluate function1 on clients private inputs
let ct_out_f1 = function1_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc);
// After server has finished evaluating the circuit on client private
// inputs, clients can proceed to multi-party decryption protocol to
// decrypt output ciphertext
// Client Side //
// In multi-party decryption protocol, client must come online, download the
// output ciphertext from the server, product "output ciphertext" dependent
// decryption share, and send it to other parties. After receiving
// decryption shares of other parties, clients independently aggregate the
// decrytion shares and decrypt the output ciphertext.
// Clients generate decryption shares
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_out_f1))
.collect_vec();
// After receiving decryption shares from other parties, clients aggregate the
// shares and decrypt output ciphertext
let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares);
// Check correctness of function1 output
let want_f1 = function1(c0_a, c1_a, c2_a, c3_a);
assert!(out_f1 == want_f1);
// --------
// Once server key is produced it can be re-used across different functions
// with different private client inputs for the same set of clients.
//
// Here we run `function2_fhe` for the same of clients but with different
// private inputs. Clients do not need to participate in the 2 round
// protocol again, instead they only upload their new private inputs to the
// server.
// Clients encrypt their private inputs
let c0_a = thread_rng().gen::<u8>();
let c0_enc = pk.encrypt(vec![c0_a].as_slice());
let c1_a = thread_rng().gen::<u8>();
let c1_enc = pk.encrypt(vec![c1_a].as_slice());
let c2_a = thread_rng().gen::<u8>();
let c2_enc = pk.encrypt(vec![c2_a].as_slice());
let c3_a = thread_rng().gen::<u8>();
let c3_enc = pk.encrypt(vec![c3_a].as_slice());
// Clients uploads only their new private inputs to the server
// Server side //
// Server receives private inputs from the clients, extracts them, and
// proceeds to evaluate `function2_fhe`
let c0_a_enc = c0_enc.extract_at(0);
let c1_a_enc = c1_enc.extract_at(0);
let c2_a_enc = c2_enc.extract_at(0);
let c3_a_enc = c3_enc.extract_at(0);
let ct_out_f2 = function2_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc);
// Client side //
// Clients generate decryption shares for `ct_out_f2`
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_out_f2))
.collect_vec();
// Clients aggregate decryption shares and decrypt `ct_out_f2`
let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares);
// We check correctness of function2
let want_f2 = function2(c0_a, c1_a, c2_a, c3_a);
assert!(want_f2 == out_f2);
}

+ 150
- 0
examples/meeting_friends.rs

@ -0,0 +1,150 @@
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
struct Location<T>(T, T);
impl<T> Location<T> {
fn new(x: T, y: T) -> Self {
Location(x, y)
}
fn x(&self) -> &T {
&self.0
}
fn y(&self) -> &T {
&self.1
}
}
fn should_meet(a: &Location<u8>, b: &Location<u8>, b_threshold: &u8) -> bool {
let diff_x = a.x() - b.x();
let diff_y = a.y() - b.y();
let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y);
d_sq.le(b_threshold)
}
/// Calculates distance square between a's and b's location. Returns a boolean
/// indicating whether diatance sqaure is <= `b_threshold`.
fn should_meet_fhe(
a: &Location<FheUint8>,
b: &Location<FheUint8>,
b_threshold: &FheUint8,
) -> FheBool {
let diff_x = a.x() - b.x();
let diff_y = a.y() - b.y();
let d_sq = &(&diff_x * &diff_x) + &(&diff_y * &diff_y);
d_sq.le(b_threshold)
}
// Ever wondered who are the long distance friends (friends of friends or
// friends of friends of friends...) that live nearby ? But how do you find
// them? Surely no-one will simply reveal their exact location just because
// there's a slight chance that a long distance friend lives nearby.
//
// Here we write a simple application with two users `a` and `b`. User `a` wants
// to find (long distance) friends that live in their neighbourhood. User `b` is
// open to meeting new friends within some distance of their location. Both user
// `a` and `b` encrypt their locations and upload their encrypted locations to
// the server. User `b` also encrypts the distance square threshold within which
// they are interested in meeting new friends and sends encrypted distance
// square threshold to the server.
// The server calculates the square of the distance between user a's location
// and user b's location and produces encrypted boolean output indicating
// whether square of distance is <= user b's supplied distance square threshold.
// User `a` then comes online, downloads output ciphertext, produces their
// decryption share for user `b`, and uploads the decryption share to the
// server. User `b` comes online, downloads output ciphertext and user a's
// decryption share, produces their own decryption share, and then decrypts the
// encrypted boolean output. If the output is `True`, it indicates user `a` is
// within the distance square threshold defined by user `b`.
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveLTE2Party);
// set application's common reference seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 2;
// Client Side //
// Generate client keys
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
// We assign user_id 0 to user `a` and user_id 1 user `b`
let a_id = 0;
let b_id = 1;
let user_a_secret = &cks[0];
let user_b_secret = &cks[1];
// User `a` and `b` generate server key shares
let a_server_key_share = gen_server_key_share(a_id, no_of_parties, user_a_secret);
let b_server_key_share = gen_server_key_share(b_id, no_of_parties, user_b_secret);
// User `a` and `b` encrypt their locations
let user_a_secret = &cks[0];
let user_a_location = Location::new(thread_rng().gen::<u8>(), thread_rng().gen::<u8>());
let user_a_enc =
user_a_secret.encrypt(vec![*user_a_location.x(), *user_a_location.y()].as_slice());
let user_b_location = Location::new(thread_rng().gen::<u8>(), thread_rng().gen::<u8>());
// User `b` also encrypts the distance square threshold
let user_b_threshold = 40;
let user_b_enc = user_b_secret
.encrypt(vec![*user_b_location.x(), *user_b_location.y(), user_b_threshold].as_slice());
// Server Side //
// Both user `a` and `b` upload their private inputs and server key shares to
// the server in single shot message
let server_key = aggregate_server_key_shares(&vec![a_server_key_share, b_server_key_share]);
server_key.set_server_key();
// Server parses private inputs from user `a` and `b`
let user_a_location_enc = {
let c = user_a_enc.unseed::<Vec<Vec<u64>>>().key_switch(a_id);
Location::new(c.extract_at(0), c.extract_at(1))
};
let (user_b_location_enc, user_b_threshold_enc) = {
let c = user_b_enc.unseed::<Vec<Vec<u64>>>().key_switch(b_id);
(
Location::new(c.extract_at(0), c.extract_at(1)),
c.extract_at(2),
)
};
// run the circuit
let out_c = should_meet_fhe(
&user_a_location_enc,
&user_b_location_enc,
&user_b_threshold_enc,
);
// Client Side //
// user `a` comes online, downloads `out_c`, produces a decryption share, and
// uploads the decryption share to the server.
let a_dec_share = user_a_secret.gen_decryption_share(&out_c);
// user `b` comes online downloads user `a`'s decryption share, generates their
// own decryption share, decrypts the output ciphertext. If the output is
// True, user `b` contacts user `a` to meet.
let b_dec_share = user_b_secret.gen_decryption_share(&out_c);
let out_bool =
user_b_secret.aggregate_decryption_shares(&out_c, &vec![b_dec_share, a_dec_share]);
assert_eq!(
out_bool,
should_meet(&user_a_location, &user_b_location, &user_b_threshold)
);
if out_bool {
println!("A lives nearby. B should meet A.");
} else {
println!("A lives too far away!")
}
}

+ 177
- 0
examples/non_interactive_fheuint8.rs

@ -0,0 +1,177 @@
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 {
((a + b) * c) * d
}
fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
&(&(a + b) * c) * d
}
fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 {
(a * b) + (c * d)
}
fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
&(a * b) + &(c * d)
}
fn main() {
set_parameter_set(ParameterSelector::NonInteractiveLTE4Party);
// set application's common reference seed
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let no_of_parties = 4;
// Clide side //
// Generate client keys
let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec();
// client 0 encrypts its private inputs
let c0_a = thread_rng().gen::<u8>();
// Clients encrypt their private inputs in a seeded batched ciphertext using
// their private RLWE secret `u_j`.
let c0_enc = cks[0].encrypt(vec![c0_a].as_slice());
// client 1 encrypts its private inputs
let c1_a = thread_rng().gen::<u8>();
let c1_enc = cks[1].encrypt(vec![c1_a].as_slice());
// client 2 encrypts its private inputs
let c2_a = thread_rng().gen::<u8>();
let c2_enc = cks[2].encrypt(vec![c2_a].as_slice());
// client 3 encrypts its private inputs
let c3_a = thread_rng().gen::<u8>();
let c3_enc = cks[3].encrypt(vec![c3_a].as_slice());
// Clients independently generate their server key shares
//
// We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client
// 2, user_id 3 to client 3.
//
// Note that `user_id`s must be unique among the clients and must be less than
// total number of clients.
let server_key_shares = cks
.iter()
.enumerate()
.map(|(id, k)| gen_server_key_share(id, no_of_parties, k))
.collect_vec();
// Each client uploads their server key shares and encrypted private inputs to
// the server in a single shot message.
// Server side //
// Server receives server key shares from each client and proceeds to aggregate
// them to produce the server key. After this point, server can use the server
// key to evaluate any arbitrary function on encrypted private inputs from
// the fixed set of clients
// aggregate server shares and generate the server key
let server_key = aggregate_server_key_shares(&server_key_shares);
server_key.set_server_key();
// Server proceeds to extract private inputs sent by clients
//
// To extract client 0's (with user_id=0) private inputs we first key switch
// client 0's private inputs from theit secret `u_j` to ideal secret of the mpc
// protocol. To indicate we're key switching client 0's private input we
// supply client 0's `user_id` i.e. we call `key_switch(0)`. Then we extract
// the first ciphertext by calling `extract_at(0)`.
//
// Since client 0 only encrypts 1 input in batched ciphertext, calling
// extract_at(index) for `index` > 0 will panic. If client 0 had more private
// inputs then we can either extract them all at once with `extract_all` or
// first `many` of them with `extract_many(many)`
let ct_c0_a = c0_enc.unseed::<Vec<Vec<u64>>>().key_switch(0).extract_at(0);
let ct_c1_a = c1_enc.unseed::<Vec<Vec<u64>>>().key_switch(1).extract_at(0);
let ct_c2_a = c2_enc.unseed::<Vec<Vec<u64>>>().key_switch(2).extract_at(0);
let ct_c3_a = c3_enc.unseed::<Vec<Vec<u64>>>().key_switch(3).extract_at(0);
// After extracting each client's private inputs, server proceeds to evaluate
// function1
let now = std::time::Instant::now();
let ct_out_f1 = function1_fhe(&ct_c0_a, &ct_c1_a, &ct_c2_a, &ct_c3_a);
println!("Function1 FHE evaluation time: {:?}", now.elapsed());
// Server has finished running compute. Clients can proceed to decrypt the
// output ciphertext using multi-party decryption.
// Client side //
// In multi-party decryption, each client needs to come online, download output
// ciphertext from the server, produce "output ciphertext" dependent decryption
// share, and send it to other parties (either via p2p or via server). After
// receving decryption shares from other parties, clients can independently
// decrypt output ciphertext.
// each client produces decryption share
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_out_f1))
.collect_vec();
// With all decryption shares, clients can aggregate the shares and decrypt the
// ciphertext
let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares);
// we check correctness of function1
let want_out_f1 = function1(c0_a, c1_a, c2_a, c3_a);
assert_eq!(out_f1, want_out_f1);
// -----------
// Server key can be re-used for different functions with different private
// client inputs for the same set of clients.
//
// Here we run `function2_fhe` for the same set of client but with new inputs.
// Clients only have to upload their private inputs to the server this time.
// Each client encrypts their private input
let c0_a = thread_rng().gen::<u8>();
let c0_enc = cks[0].encrypt(vec![c0_a].as_slice());
let c1_a = thread_rng().gen::<u8>();
let c1_enc = cks[1].encrypt(vec![c1_a].as_slice());
let c2_a = thread_rng().gen::<u8>();
let c2_enc = cks[2].encrypt(vec![c2_a].as_slice());
let c3_a = thread_rng().gen::<u8>();
let c3_enc = cks[3].encrypt(vec![c3_a].as_slice());
// Clients upload only their new private inputs to the server
// Server side //
// Server receives clients private inputs and extracts them
let ct_c0_a = c0_enc.unseed::<Vec<Vec<u64>>>().key_switch(0).extract_at(0);
let ct_c1_a = c1_enc.unseed::<Vec<Vec<u64>>>().key_switch(1).extract_at(0);
let ct_c2_a = c2_enc.unseed::<Vec<Vec<u64>>>().key_switch(2).extract_at(0);
let ct_c3_a = c3_enc.unseed::<Vec<Vec<u64>>>().key_switch(3).extract_at(0);
// Server proceeds to evaluate `function2_fhe`
let now = std::time::Instant::now();
let ct_out_f2 = function2_fhe(&ct_c0_a, &ct_c1_a, &ct_c2_a, &ct_c3_a);
println!("Function2 FHE evaluation time: {:?}", now.elapsed());
// Client side //
// Each client generates decrytion share for `ct_out_f2`
let decryption_shares = cks
.iter()
.map(|k| k.gen_decryption_share(&ct_out_f2))
.collect_vec();
// Clients independently aggregate the shares and decrypt
let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares);
// We check correctness of function2
let want_out_f2 = function2(c0_a, c1_a, c2_a, c3_a);
assert_eq!(out_f2, want_out_f2);
}

+ 0
- 163
src/backend.rs

@ -1,163 +0,0 @@
use itertools::izip;
pub trait VectorOps {
type Element;
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element);
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]);
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_neg_mut(&self, a: &mut [Self::Element]);
/// inplace mutates `a`: a = a + b*c
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]);
fn modulus(&self) -> Self::Element;
}
pub trait ArithmeticOps {
type Element;
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn modulus(&self) -> Self::Element;
}
pub struct ModularOpsU64 {
q: u64,
logq: usize,
barrett_mu: u128,
barrett_alpha: usize,
}
impl ModularOpsU64 {
pub fn new(q: u64) -> ModularOpsU64 {
let logq = 64 - q.leading_zeros();
// barrett calculation
let mu = (1u128 << (logq * 2 + 3)) / (q as u128);
let alpha = logq + 3;
ModularOpsU64 {
q,
logq: logq as usize,
barrett_alpha: alpha as usize,
barrett_mu: mu,
}
}
fn add_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q);
debug_assert!(b < self.q);
let mut o = a + b;
if o >= self.q {
o -= self.q;
}
o
}
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q);
debug_assert!(b < self.q);
if a > b {
a - b
} else {
(self.q + a) - b
}
}
/// returns (a * b) % q
///
/// - both a and b must be in range [0, 2q)
/// - output is in range [0 , q)
fn mul_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < 2 * self.q);
debug_assert!(b < 2 * self.q);
let ab = a as u128 * b as u128;
// ab / (2^{n + \beta})
// note: \beta is assumed to -2
let tmp = ab >> (self.logq - 2);
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
// ab - k*p
let tmp = k * (self.q as u128);
let mut out = (ab - tmp) as u64;
if out >= self.q {
out -= self.q;
}
return out;
}
}
impl ArithmeticOps for ModularOpsU64 {
type Element = u64;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast(*a, *b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast(*a, *b)
}
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.sub_mod_fast(*a, *b)
}
fn modulus(&self) -> Self::Element {
self.q
}
}
impl VectorOps for ModularOpsU64 {
type Element = u64;
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.add_mod_fast(*ai, *bi);
});
}
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.mul_mod_fast(*ai, *bi);
});
}
fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
a.iter_mut().for_each(|ai| *ai = self.q - *ai);
}
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| {
*oi = self.mul_mod_fast(*ai, *b);
});
}
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| {
*oi = self.mul_mod_fast(*ai, *bi);
});
}
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| {
*ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci));
});
}
fn modulus(&self) -> Self::Element {
self.q
}
}

+ 141
- 0
src/backend/mod.rs

@ -0,0 +1,141 @@
use num_traits::ToPrimitive;
use crate::{utils::log2, Row};
mod modulus_u64;
mod power_of_2;
mod word_size;
pub use modulus_u64::ModularOpsU64;
pub(crate) use power_of_2::ModulusPowerOf2;
pub trait Modulus {
type Element;
/// Modulus value if it fits in Element
fn q(&self) -> Option<Self::Element>;
/// Log2 of `q`
fn log_q(&self) -> usize;
/// Modulus value as f64 if it fits in f64
fn q_as_f64(&self) -> Option<f64>;
/// Is modulus native?
fn is_native(&self) -> bool;
/// -1 in signed representaiton
fn neg_one(&self) -> Self::Element;
/// Largest unsigned value that fits in the modulus. That is, q - 1.
fn largest_unsigned_value(&self) -> Self::Element;
/// Smallest unsigned value that fits in the modulus
/// Always assmed to be 0.
fn smallest_unsigned_value(&self) -> Self::Element;
/// Convert unsigned value in signed represetation to i64
fn map_element_to_i64(&self, v: &Self::Element) -> i64;
/// Convert f64 to signed represented in modulus
fn map_element_from_f64(&self, v: f64) -> Self::Element;
/// Convert i64 to signed represented in modulus
fn map_element_from_i64(&self, v: i64) -> Self::Element;
}
impl Modulus for u64 {
type Element = u64;
fn is_native(&self) -> bool {
// q that fits in u64 can never be a native modulus
false
}
fn largest_unsigned_value(&self) -> Self::Element {
self - 1
}
fn neg_one(&self) -> Self::Element {
self - 1
}
fn smallest_unsigned_value(&self) -> Self::Element {
0
}
fn map_element_to_i64(&self, v: &Self::Element) -> i64 {
assert!(v <= self, "{v} must be <= {self}");
if *v >= (self >> 1) {
-ToPrimitive::to_i64(&(self - v)).unwrap()
} else {
ToPrimitive::to_i64(v).unwrap()
}
}
fn map_element_from_f64(&self, v: f64) -> Self::Element {
let v = v.round();
let v_u64 = v.abs().to_u64().unwrap();
assert!(v_u64 <= self.largest_unsigned_value());
if v < 0.0 {
self - v_u64
} else {
v_u64
}
}
fn map_element_from_i64(&self, v: i64) -> Self::Element {
let v_u64 = v.abs().to_u64().unwrap();
assert!(v_u64 <= self.largest_unsigned_value());
if v < 0 {
self - v_u64
} else {
v_u64
}
}
fn q(&self) -> Option<Self::Element> {
Some(*self)
}
fn q_as_f64(&self) -> Option<f64> {
self.to_f64()
}
fn log_q(&self) -> usize {
log2(&self.q().unwrap())
}
}
pub trait ModInit {
type M;
fn new(modulus: Self::M) -> Self;
}
pub trait GetModulus {
type Element;
type M: Modulus<Element = Self::Element>;
fn modulus(&self) -> &Self::M;
}
pub trait VectorOps {
type Element;
/// Sets out as `out[i] = a[i] * b`
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element);
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]);
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]);
fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element);
fn elwise_neg_mut(&self, a: &mut [Self::Element]);
/// inplace mutates `a`: a = a + b*c
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]);
fn elwise_fma_scalar_mut(
&self,
a: &mut [Self::Element],
b: &[Self::Element],
c: &Self::Element,
);
}
pub trait ArithmeticOps {
type Element;
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn neg(&self, a: &Self::Element) -> Self::Element;
}
pub trait ArithmeticLazyOps {
type Element;
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
}
pub trait ShoupMatrixFMA<R: Row> {
/// Returns summation of `row-wise product of matrix a and b` + out where
/// each element is in range [0, 2q)
fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]);
}

+ 337
- 0
src/backend/modulus_u64.rs

@ -0,0 +1,337 @@
use itertools::izip;
use num_traits::WrappingMul;
use super::{
ArithmeticLazyOps, ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps,
};
use crate::RowMut;
pub struct ModularOpsU64<T> {
q: u64,
q_twice: u64,
logq: usize,
barrett_mu: u128,
barrett_alpha: usize,
modulus: T,
}
impl<T> ModInit for ModularOpsU64<T>
where
T: Modulus<Element = u64>,
{
type M = T;
fn new(modulus: Self::M) -> ModularOpsU64<T> {
assert!(!modulus.is_native());
// largest unsigned value modulus fits is modulus-1
let q = modulus.largest_unsigned_value() + 1;
let logq = 64 - (q + 1u64).leading_zeros();
// barrett calculation
let mu = (1u128 << (logq * 2 + 3)) / (q as u128);
let alpha = logq + 3;
ModularOpsU64 {
q,
q_twice: q << 1,
logq: logq as usize,
barrett_alpha: alpha as usize,
barrett_mu: mu,
modulus,
}
}
}
impl<T> ModularOpsU64<T> {
fn add_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q);
debug_assert!(b < self.q);
let mut o = a + b;
if o >= self.q {
o -= self.q;
}
o
}
fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q_twice);
debug_assert!(b < self.q_twice);
let mut o = a + b;
if o >= self.q_twice {
o -= self.q_twice;
}
o
}
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q);
debug_assert!(b < self.q);
if a >= b {
a - b
} else {
(self.q + a) - b
}
}
// returns (a * b) % q
///
/// - both a and b must be in range [0, 2q)
/// - output is in range [0 , 2q)
fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < 2 * self.q);
debug_assert!(b < 2 * self.q);
let ab = a as u128 * b as u128;
// ab / (2^{n + \beta})
// note: \beta is assumed to -2
let tmp = ab >> (self.logq - 2);
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
// ab - k*p
let tmp = k * (self.q as u128);
(ab - tmp) as u64
}
/// returns (a * b) % q
///
/// - both a and b must be in range [0, 2q)
/// - output is in range [0 , q)
fn mul_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < 2 * self.q);
debug_assert!(b < 2 * self.q);
let ab = a as u128 * b as u128;
// ab / (2^{n + \beta})
// note: \beta is assumed to -2
let tmp = ab >> (self.logq - 2);
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
// ab - k*p
let tmp = k * (self.q as u128);
let mut out = (ab - tmp) as u64;
if out >= self.q {
out -= self.q;
}
return out;
}
}
impl<T> ArithmeticOps for ModularOpsU64<T> {
type Element = u64;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast(*a, *b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast(*a, *b)
}
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.sub_mod_fast(*a, *b)
}
fn neg(&self, a: &Self::Element) -> Self::Element {
self.q - *a
}
// fn modulus(&self) -> Self::Element {
// self.q
// }
}
impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
type Element = u64;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast_lazy(*a, *b)
}
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast_lazy(*a, *b)
}
}
impl<T> VectorOps for ModularOpsU64<T> {
type Element = u64;
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.add_mod_fast(*ai, *bi);
});
}
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.sub_mod_fast(*ai, *bi);
});
}
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.mul_mod_fast(*ai, *bi);
});
}
fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
a.iter_mut().for_each(|ai| *ai = self.q - *ai);
}
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| {
*oi = self.mul_mod_fast(*ai, *b);
});
}
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| {
*oi = self.mul_mod_fast(*ai, *bi);
});
}
fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) {
a.iter_mut().for_each(|ai| {
*ai = self.mul_mod_fast(*ai, *b);
});
}
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| {
*ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci));
});
}
fn elwise_fma_scalar_mut(
&self,
a: &mut [Self::Element],
b: &[Self::Element],
c: &Self::Element,
) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *c));
});
}
// fn modulus(&self) -> Self::Element {
// self.q
// }
}
impl<R: RowMut<Element = u64>, T> ShoupMatrixFMA<R> for ModularOpsU64<T> {
fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]) {
assert!(a.len() == a_shoup.len());
assert!(
a.len() == b.len(),
"Unequal length {}!={}",
a.len(),
b.len()
);
let q = self.q;
let q_twice = self.q << 1;
izip!(a.iter(), a_shoup.iter(), b.iter()).for_each(|(a_row, a_shoup_row, b_row)| {
izip!(
out.as_mut().iter_mut(),
a_row.as_ref().iter(),
a_shoup_row.as_ref().iter(),
b_row.as_ref().iter()
)
.for_each(|(o, a0, a0_shoup, b0)| {
let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
v = v.wrapping_sub(q.wrapping_mul(quotient));
if v >= q_twice {
v -= q_twice;
}
*o = v;
});
});
}
}
impl<T> GetModulus for ModularOpsU64<T>
where
T: Modulus,
{
type Element = T::Element;
type M = T;
fn modulus(&self) -> &Self::M {
&self.modulus
}
}
#[cfg(test)]
mod tests {
use super::*;
use itertools::Itertools;
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
#[test]
fn fma() {
let mut rng = thread_rng();
let prime = 36028797017456641;
let ring_size = 1 << 3;
let dist = Uniform::new(0, prime);
let d = 2;
let a0_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
// a0 in shoup representation
let a0_shoup_matrix = a0_matrix
.iter()
.map(|r| {
r.iter()
.map(|v| {
// $(v * 2^{\beta}) / p$
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
})
.collect_vec()
})
.collect_vec();
let a1_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
let modop = ModularOpsU64::new(prime);
let mut out_shoup_fma_lazy = vec![0u64; ring_size];
modop.shoup_matrix_fma(
&mut out_shoup_fma_lazy,
&a0_matrix,
&a0_shoup_matrix,
&a1_matrix,
);
let out_shoup_fma = out_shoup_fma_lazy
.iter()
.map(|v| if *v >= prime { v - prime } else { *v })
.collect_vec();
// expected
let mut out_expected = vec![0u64; ring_size];
izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
*o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
});
});
assert_eq!(out_expected, out_shoup_fma);
}
}

+ 112
- 0
src/backend/power_of_2.rs

@ -0,0 +1,112 @@
use itertools::izip;
use crate::{ArithmeticOps, ModInit, VectorOps};
use super::{GetModulus, Modulus};
pub(crate) struct ModulusPowerOf2<T> {
modulus: T,
/// Modulus mask: (1 << q) - 1
mask: u64,
}
impl<T> ArithmeticOps for ModulusPowerOf2<T> {
type Element = u64;
#[inline]
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
(a.wrapping_add(*b)) & self.mask
}
#[inline]
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
(a.wrapping_sub(*b)) & self.mask
}
#[inline]
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
(a.wrapping_mul(*b)) & self.mask
}
#[inline]
fn neg(&self, a: &Self::Element) -> Self::Element {
(0u64.wrapping_sub(*a)) & self.mask
}
}
impl<T> VectorOps for ModulusPowerOf2<T> {
type Element = u64;
#[inline]
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_add(*b0)) & self.mask);
}
#[inline]
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_mul(*b0)) & self.mask);
}
#[inline]
fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
a.iter_mut()
.for_each(|a0| *a0 = 0u64.wrapping_sub(*a0) & self.mask);
}
#[inline]
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| *a0 = (a0.wrapping_sub(*b0)) & self.mask);
}
#[inline]
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(a0, b0, c0)| {
*a0 = a0.wrapping_add(b0.wrapping_mul(*c0)) & self.mask;
});
}
#[inline]
fn elwise_fma_scalar_mut(
&self,
a: &mut [Self::Element],
b: &[Self::Element],
c: &Self::Element,
) {
izip!(a.iter_mut(), b.iter()).for_each(|(a0, b0)| {
*a0 = a0.wrapping_add(b0.wrapping_mul(*c)) & self.mask;
});
}
#[inline]
fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) {
a.iter_mut()
.for_each(|a0| *a0 = a0.wrapping_mul(*b) & self.mask)
}
#[inline]
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(o0, a0, b0)| {
*o0 = a0.wrapping_mul(*b0) & self.mask;
});
}
#[inline]
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
izip!(out.iter_mut(), a.iter()).for_each(|(o0, a0)| {
*o0 = a0.wrapping_mul(*b) & self.mask;
});
}
}
impl<T: Modulus<Element = u64>> ModInit for ModulusPowerOf2<T> {
type M = T;
fn new(modulus: Self::M) -> Self {
assert!(!modulus.is_native());
assert!(modulus.q().unwrap().is_power_of_two());
let q = modulus.q().unwrap();
let mask = q - 1;
Self { modulus, mask }
}
}
impl<T: Modulus<Element = u64>> GetModulus for ModulusPowerOf2<T> {
type Element = u64;
type M = T;
fn modulus(&self) -> &Self::M {
&self.modulus
}
}

+ 124
- 0
src/backend/word_size.rs

@ -0,0 +1,124 @@
use itertools::izip;
use num_traits::{WrappingAdd, WrappingMul, WrappingSub, Zero};
use super::{ArithmeticOps, GetModulus, ModInit, Modulus, VectorOps};
pub struct WordSizeModulus<T> {
modulus: T,
}
impl<T> ModInit for WordSizeModulus<T>
where
T: Modulus,
{
type M = T;
fn new(modulus: T) -> Self {
assert!(modulus.is_native());
// For now assume ModulusOpsU64 is only used for u64
Self { modulus: modulus }
}
}
impl<T> ArithmeticOps for WordSizeModulus<T>
where
T: Modulus,
T::Element: WrappingAdd + WrappingSub + WrappingMul + Zero,
{
type Element = T::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
T::Element::wrapping_add(a, b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
T::Element::wrapping_mul(a, b)
}
fn neg(&self, a: &Self::Element) -> Self::Element {
T::Element::wrapping_sub(&T::Element::zero(), a)
}
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
T::Element::wrapping_sub(a, b)
}
}
impl<T> VectorOps for WordSizeModulus<T>
where
T: Modulus,
T::Element: WrappingAdd + WrappingSub + WrappingMul + Zero,
{
type Element = T::Element;
fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = T::Element::wrapping_add(ai, bi);
});
}
fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = T::Element::wrapping_sub(ai, bi);
});
}
fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = T::Element::wrapping_mul(ai, bi);
});
}
fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
a.iter_mut()
.for_each(|ai| *ai = T::Element::wrapping_sub(&T::Element::zero(), ai));
}
fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| {
*oi = T::Element::wrapping_mul(ai, b);
});
}
fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| {
*oi = T::Element::wrapping_mul(ai, bi);
});
}
fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) {
a.iter_mut().for_each(|ai| {
*ai = T::Element::wrapping_mul(ai, b);
});
}
fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| {
*ai = T::Element::wrapping_add(ai, &T::Element::wrapping_mul(bi, ci));
});
}
fn elwise_fma_scalar_mut(
&self,
a: &mut [Self::Element],
b: &[Self::Element],
c: &Self::Element,
) {
izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
*ai = T::Element::wrapping_add(ai, &T::Element::wrapping_mul(bi, c));
});
}
// fn modulus(&self) -> &T {
// &self.modulus
// }
}
impl<T> GetModulus for WordSizeModulus<T>
where
T: Modulus,
{
type Element = T::Element;
type M = T;
fn modulus(&self) -> &Self::M {
&self.modulus
}
}

+ 2323
- 0
src/bool/evaluator.rs
File diff suppressed because it is too large
View File


+ 1559
- 0
src/bool/keys.rs
File diff suppressed because it is too large
View File


+ 266
- 0
src/bool/mod.rs

@ -0,0 +1,266 @@
mod evaluator;
mod keys;
pub(crate) mod parameters;
#[cfg(feature = "interactive_mp")]
mod mp_api;
#[cfg(feature = "non_interactive_mp")]
mod ni_mp_api;
#[cfg(feature = "non_interactive_mp")]
pub use ni_mp_api::*;
#[cfg(feature = "interactive_mp")]
pub use mp_api::*;
use crate::RowEntity;
pub type ClientKey = keys::ClientKey<[u8; 32], u64>;
#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))]
pub type FheBool = impl_bool_frontend::FheBool<Vec<u64>>;
pub(crate) trait BooleanGates {
type Ciphertext: RowEntity;
type Key;
fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn or_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn nor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn xor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn xnor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key);
fn not_inplace(&self, c: &mut Self::Ciphertext);
fn and(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn nand(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn or(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn nor(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn xor(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn xnor(
&mut self,
c0: &Self::Ciphertext,
c1: &Self::Ciphertext,
key: &Self::Key,
) -> Self::Ciphertext;
fn not(&self, c: &Self::Ciphertext) -> Self::Ciphertext;
}
#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))]
mod impl_bool_frontend {
use crate::MultiPartyDecryptor;
/// Fhe Bool ciphertext
#[derive(Clone)]
pub struct FheBool<C> {
pub(crate) data: C,
}
impl<C> FheBool<C> {
pub(crate) fn data(&self) -> &C {
&self.data
}
pub(crate) fn data_mut(&mut self) -> &mut C {
&mut self.data
}
}
impl<C, K> MultiPartyDecryptor<bool, FheBool<C>> for K
where
K: MultiPartyDecryptor<bool, C>,
{
type DecryptionShare = <K as MultiPartyDecryptor<bool, C>>::DecryptionShare;
fn aggregate_decryption_shares(
&self,
c: &FheBool<C>,
shares: &[Self::DecryptionShare],
) -> bool {
self.aggregate_decryption_shares(&c.data, shares)
}
fn gen_decryption_share(&self, c: &FheBool<C>) -> Self::DecryptionShare {
self.gen_decryption_share(&c.data)
}
}
mod ops {
use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not};
use crate::{
utils::{Global, WithLocal},
BooleanGates,
};
use super::super::{BoolEvaluator, RuntimeServerKey};
type FheBool = super::super::FheBool;
impl BitAnd for &FheBool {
type Output = FheBool;
fn bitand(self, rhs: Self) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
FheBool {
data: e.and(self.data(), rhs.data(), key),
}
})
}
}
impl BitAndAssign for FheBool {
fn bitand_assign(&mut self, rhs: Self) {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
e.and_inplace(&mut self.data_mut(), rhs.data(), key);
});
}
}
impl BitOr for &FheBool {
type Output = FheBool;
fn bitor(self, rhs: Self) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
FheBool {
data: e.or(self.data(), rhs.data(), key),
}
})
}
}
impl BitOrAssign for FheBool {
fn bitor_assign(&mut self, rhs: Self) {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
e.or_inplace(&mut self.data_mut(), rhs.data(), key);
});
}
}
impl BitXor for &FheBool {
type Output = FheBool;
fn bitxor(self, rhs: Self) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
FheBool {
data: e.xor(self.data(), rhs.data(), key),
}
})
}
}
impl BitXorAssign for FheBool {
fn bitxor_assign(&mut self, rhs: Self) {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
e.xor_inplace(&mut self.data_mut(), rhs.data(), key);
});
}
}
impl Not for &FheBool {
type Output = FheBool;
fn not(self) -> Self::Output {
BoolEvaluator::with_local(|e| FheBool {
data: e.not(self.data()),
})
}
}
}
}
#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))]
mod common_mp_enc_dec {
use itertools::Itertools;
use super::BoolEvaluator;
use crate::{
pbs::{sample_extract, PbsInfo},
utils::WithLocal,
Matrix, RowEntity, SampleExtractor,
};
type Mat = Vec<Vec<u64>>;
impl SampleExtractor<<Mat as Matrix>::R> for Mat {
/// Sample extract coefficient at `index` as a LWE ciphertext from RLWE
/// ciphertext `Self`
fn extract_at(&self, index: usize) -> <Mat as Matrix>::R {
// input is RLWE ciphertext
assert!(self.dimension().0 == 2);
let ring_size = self.dimension().1;
assert!(index < ring_size);
BoolEvaluator::with_local(|e| {
let mut lwe_out = <Mat as Matrix>::R::zeros(ring_size + 1);
sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index);
lwe_out
})
}
/// Extract first `how_many` coefficients of `Self` as LWE ciphertexts
fn extract_many(&self, how_many: usize) -> Vec<<Mat as Matrix>::R> {
assert!(self.dimension().0 == 2);
let ring_size = self.dimension().1;
assert!(how_many <= ring_size);
(0..how_many)
.map(|index| {
BoolEvaluator::with_local(|e| {
let mut lwe_out = <Mat as Matrix>::R::zeros(ring_size + 1);
sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index);
lwe_out
})
})
.collect_vec()
}
/// Extracts all coefficients of `Self` as LWE ciphertexts
fn extract_all(&self) -> Vec<<Mat as Matrix>::R> {
assert!(self.dimension().0 == 2);
let ring_size = self.dimension().1;
(0..ring_size)
.map(|index| {
BoolEvaluator::with_local(|e| {
let mut lwe_out = <Mat as Matrix>::R::zeros(ring_size + 1);
sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index);
lwe_out
})
})
.collect_vec()
}
}
}
#[cfg(test)]
mod print_noise;

+ 697
- 0
src/bool/mp_api.rs

@ -0,0 +1,697 @@
use std::{cell::RefCell, sync::OnceLock};
use crate::{
backend::{ModularOpsU64, ModulusPowerOf2},
ntt::NttBackendU64,
random::{DefaultSecureRng, NewWithSeed},
utils::{Global, WithLocal},
};
use super::{evaluator::InteractiveMultiPartyCrs, keys::*, parameters::*, ClientKey};
pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator<
Vec<Vec<u64>>,
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModulusPowerOf2<CiphertextModulus<u64>>,
ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>;
thread_local! {
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator>> = RefCell::new(None);
}
static BOOL_SERVER_KEY: OnceLock<ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>> = OnceLock::new();
static MULTI_PARTY_CRS: OnceLock<InteractiveMultiPartyCrs<[u8; 32]>> = OnceLock::new();
pub enum ParameterSelector {
InteractiveLTE2Party,
InteractiveLTE4Party,
InteractiveLTE8Party,
}
/// Select Interactive multi-party parameter variant
pub fn set_parameter_set(select: ParameterSelector) {
match select {
ParameterSelector::InteractiveLTE2Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_2P_LB_SR)));
}
ParameterSelector::InteractiveLTE4Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_4P)));
}
ParameterSelector::InteractiveLTE8Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(I_8P)));
}
}
}
/// Set application specific interactive multi-party common reference string
pub fn set_common_reference_seed(seed: [u8; 32]) {
assert!(
MULTI_PARTY_CRS
.set(InteractiveMultiPartyCrs { seed: seed })
.is_ok(),
"Attempted to set MP SEED twice."
)
}
/// Generate client key for interactive multi-party protocol
pub fn gen_client_key() -> ClientKey {
BoolEvaluator::with_local(|e| e.client_key())
}
/// Generate client's share for collective public key, i.e round 1 share, of the
/// 2 round protocol
pub fn collective_pk_share(
ck: &ClientKey,
) -> CommonReferenceSeededCollectivePublicKeyShare<Vec<u64>, [u8; 32], BoolParameters<u64>> {
BoolEvaluator::with_local(|e| {
let pk_share = e.multi_party_public_key_share(InteractiveMultiPartyCrs::global(), ck);
pk_share
})
}
/// Generate clients share for collective server key, i.e. round 2, of the
/// 2 round protocol
pub fn collective_server_key_share<R, ModOp>(
ck: &ClientKey,
user_id: usize,
total_users: usize,
pk: &PublicKey<Vec<Vec<u64>>, R, ModOp>,
) -> CommonReferenceSeededInteractiveMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
InteractiveMultiPartyCrs<[u8; 32]>,
> {
BoolEvaluator::with_local_mut(|e| {
let server_key_share = e.gen_interactive_multi_party_server_key_share(
user_id,
total_users,
InteractiveMultiPartyCrs::global(),
pk.key(),
ck,
);
server_key_share
})
}
/// Aggregate public key shares from all parties.
///
/// Public key shares are generated per client in round 1. Aggregation of public
/// key shares marks the end of round 1.
pub fn aggregate_public_key_shares(
shares: &[CommonReferenceSeededCollectivePublicKeyShare<
Vec<u64>,
[u8; 32],
BoolParameters<u64>,
>],
) -> PublicKey<Vec<Vec<u64>>, DefaultSecureRng, ModularOpsU64<CiphertextModulus<u64>>> {
PublicKey::from(shares)
}
/// Aggregate server key shares
pub fn aggregate_server_key_shares(
shares: &[CommonReferenceSeededInteractiveMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
InteractiveMultiPartyCrs<[u8; 32]>,
>],
) -> SeededInteractiveMultiPartyServerKey<
Vec<Vec<u64>>,
InteractiveMultiPartyCrs<[u8; 32]>,
BoolParameters<u64>,
> {
BoolEvaluator::with_local(|e| e.aggregate_interactive_multi_party_server_key_shares(shares))
}
impl
SeededInteractiveMultiPartyServerKey<
Vec<Vec<u64>>,
InteractiveMultiPartyCrs<<DefaultSecureRng as NewWithSeed>::Seed>,
BoolParameters<u64>,
>
{
/// Sets the server key as a global reference for circuit evaluation
pub fn set_server_key(&self) {
assert!(
BOOL_SERVER_KEY
.set(ShoupServerKeyEvaluationDomain::from(
ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self),
))
.is_ok(),
"Attempted to set server key twice."
);
}
}
// MULTIPARTY CRS //
impl Global for InteractiveMultiPartyCrs<[u8; 32]> {
fn global() -> &'static Self {
MULTI_PARTY_CRS
.get()
.expect("Multi Party Common Reference String not set")
}
}
// BOOL EVALUATOR //
impl WithLocal for BoolEvaluator {
fn with_local<F, R>(func: F) -> R
where
F: Fn(&Self) -> R,
{
BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set")))
}
fn with_local_mut<F, R>(func: F) -> R
where
F: Fn(&mut Self) -> R,
{
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
}
fn with_local_mut_mut<F, R>(func: &mut F) -> R
where
F: FnMut(&mut Self) -> R,
{
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
}
}
pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain<Vec<Vec<u64>>>;
impl Global for RuntimeServerKey {
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().expect("Server key not set!")
}
}
mod impl_enc_dec {
use crate::{
bool::evaluator::BoolEncoding,
multi_party::{
multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share,
},
pbs::{sample_extract, PbsInfo},
rgsw::public_key_encrypt_rlwe,
utils::TryConvertFrom1,
Encryptor, Matrix, MatrixEntity, MultiPartyDecryptor, RowEntity,
};
use itertools::Itertools;
use num_traits::{ToPrimitive, Zero};
use super::*;
type Mat = Vec<Vec<u64>>;
impl<Rng, ModOp> Encryptor<[bool], Vec<Mat>> for PublicKey<Mat, Rng, ModOp> {
fn encrypt(&self, m: &[bool]) -> Vec<Mat> {
BoolEvaluator::with_local(|e| {
DefaultSecureRng::with_local_mut(|rng| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil())
.to_usize()
.unwrap();
// encrypt `m` into ceil(len(m)/N) RLWE ciphertexts
let rlwes = (0..rlwe_count)
.map(|index| {
let mut message = vec![<Mat as Matrix>::MatElement::zero(); ring_size];
m[(index * ring_size)..std::cmp::min(m.len(), (index + 1) * ring_size)]
.iter()
.enumerate()
.for_each(|(i, v)| {
if *v {
message[i] = parameters.rlwe_q().true_el()
} else {
message[i] = parameters.rlwe_q().false_el()
}
});
// encrypt message
let mut rlwe_out =
<Mat as MatrixEntity>::zeros(2, parameters.rlwe_n().0);
public_key_encrypt_rlwe::<_, _, _, _, i32, _>(
&mut rlwe_out,
self.key(),
&message,
e.pbs_info().modop_rlweq(),
e.pbs_info().nttop_rlweq(),
rng,
);
rlwe_out
})
.collect_vec();
rlwes
})
})
}
}
impl<Rng, ModOp> Encryptor<bool, <Mat as Matrix>::R> for PublicKey<Mat, Rng, ModOp> {
fn encrypt(&self, m: &bool) -> <Mat as Matrix>::R {
let m = vec![*m];
let rlwe = &self.encrypt(m.as_slice())[0];
BoolEvaluator::with_local(|e| {
let mut lwe = <Mat as Matrix>::R::zeros(e.parameters().rlwe_n().0 + 1);
sample_extract(&mut lwe, rlwe, e.pbs_info().modop_rlweq(), 0);
lwe
})
}
}
impl<K> MultiPartyDecryptor<bool, <Mat as Matrix>::R> for K
where
K: InteractiveMultiPartyClientKey,
<Mat as Matrix>::R:
TryConvertFrom1<[K::Element], CiphertextModulus<<Mat as Matrix>::MatElement>>,
{
type DecryptionShare = <Mat as Matrix>::MatElement;
fn gen_decryption_share(&self, c: &<Mat as Matrix>::R) -> Self::DecryptionShare {
BoolEvaluator::with_local(|e| {
DefaultSecureRng::with_local_mut(|rng| {
multi_party_decryption_share(
c,
self.sk_rlwe().as_slice(),
e.pbs_info().modop_rlweq(),
rng,
)
})
})
}
fn aggregate_decryption_shares(
&self,
c: &<Mat as Matrix>::R,
shares: &[Self::DecryptionShare],
) -> bool {
BoolEvaluator::with_local(|e| {
let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt(
c,
shares,
e.pbs_info().modop_rlweq(),
);
e.pbs_info().rlwe_q().decode(noisy_m)
})
}
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use rand::{thread_rng, Rng, RngCore};
use crate::{bool::evaluator::BoolEncoding, Encryptor, MultiPartyDecryptor, SampleExtractor};
use super::*;
#[test]
fn batched_fhe_u8s_extract_works() {
set_parameter_set(ParameterSelector::InteractiveLTE2Party);
let mut seed = [0u8; 32];
thread_rng().fill_bytes(&mut seed);
set_common_reference_seed(seed);
let parties = 2;
let cks = (0..parties).map(|_| gen_client_key()).collect_vec();
// round 1
let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec();
// collective pk
let pk = aggregate_public_key_shares(&pk_shares);
let parameters = BoolEvaluator::with_local(|e| e.parameters().clone());
let batch_size = parameters.rlwe_n().0 * 3 + 123;
let m = (0..batch_size)
.map(|_| thread_rng().gen::<u8>())
.collect_vec();
let seeded_ct = pk.encrypt(m.as_slice());
let m_back = (0..batch_size)
.map(|i| {
let ct = seeded_ct.extract_at(i);
cks[0].aggregate_decryption_shares(
&ct,
&cks.iter()
.map(|k| k.gen_decryption_share(&ct))
.collect_vec(),
)
})
.collect_vec();
assert_eq!(m, m_back);
}
mod sp_api {
use num_traits::ToPrimitive;
use crate::{
bool::impl_bool_frontend::FheBool, pbs::PbsInfo, rgsw::seeded_secret_key_encrypt_rlwe,
Decryptor,
};
use super::*;
pub(crate) fn set_single_party_parameter_sets(parameter: BoolParameters<u64>) {
BOOL_EVALUATOR.with_borrow_mut(|e| *e = Some(BoolEvaluator::new(parameter)));
}
// SERVER KEY EVAL (/SHOUP) DOMAIN //
impl SeededSinglePartyServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]> {
pub fn set_server_key(&self) {
assert!(
BOOL_SERVER_KEY
.set(
ShoupServerKeyEvaluationDomain::from(ServerKeyEvaluationDomain::<
_,
_,
DefaultSecureRng,
NttBackendU64,
>::from(
self
),)
)
.is_ok(),
"Attempted to set server key twice."
);
}
}
pub(crate) fn gen_keys() -> (
ClientKey,
SeededSinglePartyServerKey<Vec<Vec<u64>>, BoolParameters<u64>, [u8; 32]>,
) {
super::BoolEvaluator::with_local_mut(|e| {
let ck = e.client_key();
let sk = e.single_party_server_key(&ck);
(ck, sk)
})
}
impl<K: SinglePartyClientKey<Element = i32>> Encryptor<bool, Vec<u64>> for K {
fn encrypt(&self, m: &bool) -> Vec<u64> {
BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self))
}
}
impl<K: SinglePartyClientKey<Element = i32>> Decryptor<bool, Vec<u64>> for K {
fn decrypt(&self, c: &Vec<u64>) -> bool {
BoolEvaluator::with_local(|e| e.sk_decrypt(c, self))
}
}
impl<K: SinglePartyClientKey<Element = i32>, C> Encryptor<bool, FheBool<C>> for K
where
K: Encryptor<bool, C>,
{
fn encrypt(&self, m: &bool) -> FheBool<C> {
FheBool {
data: self.encrypt(m),
}
}
}
impl<K: SinglePartyClientKey<Element = i32>, C> Decryptor<bool, FheBool<C>> for K
where
K: Decryptor<bool, C>,
{
fn decrypt(&self, c: &FheBool<C>) -> bool {
self.decrypt(c.data())
}
}
impl<K> Encryptor<[bool], (Vec<Vec<u64>>, [u8; 32])> for K
where
K: SinglePartyClientKey<Element = i32>,
{
fn encrypt(&self, m: &[bool]) -> (Vec<Vec<u64>>, [u8; 32]) {
BoolEvaluator::with_local(|e| {
DefaultSecureRng::with_local_mut(|rng| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil())
.to_usize()
.unwrap();
let mut seed = <DefaultSecureRng as NewWithSeed>::Seed::default();
rng.fill_bytes(&mut seed);
let mut prng = DefaultSecureRng::new_seeded(seed);
let sk_u = self.sk_rlwe();
// encrypt `m` into ceil(len(m)/N) RLWE ciphertexts
let rlwes = (0..rlwe_count)
.map(|index| {
let mut message = vec![0; ring_size];
m[(index * ring_size)
..std::cmp::min(m.len(), (index + 1) * ring_size)]
.iter()
.enumerate()
.for_each(|(i, v)| {
if *v {
message[i] = parameters.rlwe_q().true_el()
} else {
message[i] = parameters.rlwe_q().false_el()
}
});
// encrypt message
let mut rlwe_out = vec![0u64; parameters.rlwe_n().0];
seeded_secret_key_encrypt_rlwe(
&message,
&mut rlwe_out,
&sk_u,
e.pbs_info().modop_rlweq(),
e.pbs_info().nttop_rlweq(),
&mut prng,
rng,
);
rlwe_out
})
.collect_vec();
(rlwes, seed)
})
})
}
}
#[test]
#[cfg(feature = "interactive_mp")]
fn all_uint8_apis() {
use num_traits::Euclid;
use crate::{div_zero_error_flag, FheBool};
set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS);
let (ck, sk) = gen_keys();
sk.set_server_key();
for i in 0..=255 {
for j in 0..=255 {
let m0 = i;
let m1 = j;
let c0 = ck.encrypt(&m0);
let c1 = ck.encrypt(&m1);
assert!(ck.decrypt(&c0) == m0);
assert!(ck.decrypt(&c1) == m1);
// Arithmetic
{
{
// Add
let c_add = &c0 + &c1;
let m0_plus_m1 = ck.decrypt(&c_add);
assert_eq!(
m0_plus_m1,
m0.wrapping_add(m1),
"Expected {} but got {m0_plus_m1} for
{i}+{j}",
m0.wrapping_add(m1)
);
}
{
// Sub
let c_sub = &c0 - &c1;
let m0_sub_m1 = ck.decrypt(&c_sub);
assert_eq!(
m0_sub_m1,
m0.wrapping_sub(m1),
"Expected {} but got {m0_sub_m1} for
{i}-{j}",
m0.wrapping_sub(m1)
);
}
{
// Mul
let c_m0m1 = &c0 * &c1;
let m0m1 = ck.decrypt(&c_m0m1);
assert_eq!(
m0m1,
m0.wrapping_mul(m1),
"Expected {} but got {m0m1} for {i}x{j}",
m0.wrapping_mul(m1)
);
}
// Div & Rem
{
let (c_quotient, c_rem) = c0.div_rem(&c1);
let m_quotient = ck.decrypt(&c_quotient);
let m_remainder = ck.decrypt(&c_rem);
if j != 0 {
let (q, r) = i.div_rem_euclid(&j);
assert_eq!(
m_quotient, q,
"Expected {} but got {m_quotient} for
{i}/{j}",
q
);
assert_eq!(
m_remainder, r,
"Expected {} but got {m_remainder} for
{i}%{j}",
r
);
} else {
assert_eq!(
m_quotient, 255,
"Expected 255 but got {m_quotient}. Case
div by zero"
);
assert_eq!(
m_remainder, i,
"Expected {i} but got {m_remainder}. Case
div by zero"
);
let div_by_zero = ck.decrypt(&div_zero_error_flag().unwrap());
assert_eq!(
div_by_zero, true,
"Expected true but got {div_by_zero}"
);
}
}
}
// // Comparisons
{
{
let c_eq = c0.eq(&c1);
let is_eq = ck.decrypt(&c_eq);
assert_eq!(
is_eq,
i == j,
"Expected {} but got {is_eq} for {i}=={j}",
i == j
);
}
{
let c_gt = c0.gt(&c1);
let is_gt = ck.decrypt(&c_gt);
assert_eq!(
is_gt,
i > j,
"Expected {} but got {is_gt} for {i}>{j}",
i > j
);
}
{
let c_lt = c0.lt(&c1);
let is_lt = ck.decrypt(&c_lt);
assert_eq!(
is_lt,
i < j,
"Expected {} but got {is_lt} for {i}<{j}",
i < j
);
}
{
let c_ge = c0.ge(&c1);
let is_ge = ck.decrypt(&c_ge);
assert_eq!(
is_ge,
i >= j,
"Expected {} but got {is_ge} for {i}>={j}",
i >= j
);
}
{
let c_le = c0.le(&c1);
let is_le = ck.decrypt(&c_le);
assert_eq!(
is_le,
i <= j,
"Expected {} but got {is_le} for {i}<={j}",
i <= j
);
}
}
// mux
{
let selector = thread_rng().gen_bool(0.5);
let selector_enc: FheBool = ck.encrypt(&selector);
let mux_out = ck.decrypt(&c0.mux(&c1, &selector_enc));
let want_mux_out = if selector { m0 } else { m1 };
assert_eq!(mux_out, want_mux_out);
}
}
}
}
#[test]
#[cfg(feature = "interactive_mp")]
fn all_bool_apis() {
use crate::FheBool;
set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS);
let (ck, sk) = gen_keys();
sk.set_server_key();
for _ in 0..100 {
let a = thread_rng().gen_bool(0.5);
let b = thread_rng().gen_bool(0.5);
let c_a: FheBool = ck.encrypt(&a);
let c_b: FheBool = ck.encrypt(&b);
let c_out = &c_a & &c_b;
let out = ck.decrypt(&c_out);
assert_eq!(out, a & b, "Expected {} but got {out}", a & b);
let c_out = &c_a | &c_b;
let out = ck.decrypt(&c_out);
assert_eq!(out, a | b, "Expected {} but got {out}", a | b);
let c_out = &c_a ^ &c_b;
let out = ck.decrypt(&c_out);
assert_eq!(out, a ^ b, "Expected {} but got {out}", a ^ b);
let c_out = !(&c_a);
let out = ck.decrypt(&c_out);
assert_eq!(out, !a, "Expected {} but got {out}", !a);
}
}
}
}

+ 459
- 0
src/bool/ni_mp_api.rs

@ -0,0 +1,459 @@
use std::{cell::RefCell, sync::OnceLock};
use crate::{
backend::ModulusPowerOf2,
bool::parameters::ParameterVariant,
random::DefaultSecureRng,
utils::{Global, WithLocal},
ModularOpsU64, NttBackendU64,
};
use super::{
evaluator::NonInteractiveMultiPartyCrs,
keys::{
CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare,
NonInteractiveServerKeyEvaluationDomain, SeededNonInteractiveMultiPartyServerKey,
ShoupNonInteractiveServerKeyEvaluationDomain,
},
parameters::{BoolParameters, CiphertextModulus, NI_2P, NI_4P_HB_FR, NI_8P},
ClientKey,
};
pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator<
Vec<Vec<u64>>,
NttBackendU64,
ModularOpsU64<CiphertextModulus<u64>>,
ModulusPowerOf2<CiphertextModulus<u64>>,
ShoupNonInteractiveServerKeyEvaluationDomain<Vec<Vec<u64>>>,
>;
thread_local! {
static BOOL_EVALUATOR: RefCell<Option<BoolEvaluator>> = RefCell::new(None);
}
static BOOL_SERVER_KEY: OnceLock<ShoupNonInteractiveServerKeyEvaluationDomain<Vec<Vec<u64>>>> =
OnceLock::new();
static MULTI_PARTY_CRS: OnceLock<NonInteractiveMultiPartyCrs<[u8; 32]>> = OnceLock::new();
pub enum ParameterSelector {
NonInteractiveLTE2Party,
NonInteractiveLTE4Party,
NonInteractiveLTE8Party,
}
pub fn set_parameter_set(select: ParameterSelector) {
match select {
ParameterSelector::NonInteractiveLTE2Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_2P)));
}
ParameterSelector::NonInteractiveLTE4Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_4P_HB_FR)));
}
ParameterSelector::NonInteractiveLTE8Party => {
BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(NI_8P)));
}
}
}
pub fn set_common_reference_seed(seed: [u8; 32]) {
BoolEvaluator::with_local(|e| {
assert_eq!(
e.parameters().variant(),
&ParameterVariant::NonInteractiveMultiParty,
"Set parameters do not support Non interactive multi-party"
);
});
assert!(
MULTI_PARTY_CRS
.set(NonInteractiveMultiPartyCrs { seed: seed })
.is_ok(),
"Attempted to set MP SEED twice."
)
}
pub fn gen_client_key() -> ClientKey {
BoolEvaluator::with_local(|e| e.client_key())
}
pub fn gen_server_key_share(
user_id: usize,
total_users: usize,
client_key: &ClientKey,
) -> CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
NonInteractiveMultiPartyCrs<[u8; 32]>,
> {
BoolEvaluator::with_local(|e| {
let cr_seed = NonInteractiveMultiPartyCrs::global();
e.gen_non_interactive_multi_party_key_share(cr_seed, user_id, total_users, client_key)
})
}
pub fn aggregate_server_key_shares(
shares: &[CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare<
Vec<Vec<u64>>,
BoolParameters<u64>,
NonInteractiveMultiPartyCrs<[u8; 32]>,
>],
) -> SeededNonInteractiveMultiPartyServerKey<
Vec<Vec<u64>>,
NonInteractiveMultiPartyCrs<[u8; 32]>,
BoolParameters<u64>,
> {
BoolEvaluator::with_local(|e| {
let cr_seed = NonInteractiveMultiPartyCrs::global();
e.aggregate_non_interactive_multi_party_server_key_shares(cr_seed, shares)
})
}
impl
SeededNonInteractiveMultiPartyServerKey<
Vec<Vec<u64>>,
NonInteractiveMultiPartyCrs<[u8; 32]>,
BoolParameters<u64>,
>
{
pub fn set_server_key(&self) {
let eval_key = NonInteractiveServerKeyEvaluationDomain::<
_,
BoolParameters<u64>,
DefaultSecureRng,
NttBackendU64,
>::from(self);
assert!(
BOOL_SERVER_KEY
.set(ShoupNonInteractiveServerKeyEvaluationDomain::from(eval_key))
.is_ok(),
"Attempted to set server key twice!"
);
}
}
impl Global for NonInteractiveMultiPartyCrs<[u8; 32]> {
fn global() -> &'static Self {
MULTI_PARTY_CRS
.get()
.expect("Non-interactive multi-party common reference string not set")
}
}
// BOOL EVALUATOR //
impl WithLocal for BoolEvaluator {
fn with_local<F, R>(func: F) -> R
where
F: Fn(&Self) -> R,
{
BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set")))
}
fn with_local_mut<F, R>(func: F) -> R
where
F: Fn(&mut Self) -> R,
{
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
}
fn with_local_mut_mut<F, R>(func: &mut F) -> R
where
F: FnMut(&mut Self) -> R,
{
BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set")))
}
}
pub(crate) type RuntimeServerKey = ShoupNonInteractiveServerKeyEvaluationDomain<Vec<Vec<u64>>>;
impl Global for RuntimeServerKey {
fn global() -> &'static Self {
BOOL_SERVER_KEY.get().expect("Server key not set!")
}
}
/// Batch of bool ciphertexts stored as vector of RLWE ciphertext under user j's
/// RLWE secret `u_j`
///
/// To use the bool ciphertexts in multi-party protocol first key switch the
/// ciphertexts from u_j to ideal RLWE secret `s` with
/// `self.key_switch(user_id)` where `user_id` is user j's id. Key switch
/// returns `BatchedFheBools` that stored key vector of key switched RLWE
/// ciphertext.
pub(super) struct NonInteractiveBatchedFheBools<C> {
data: Vec<C>,
}
/// Batch of Bool cipphertexts stored as vector of RLWE ciphertexts under the
/// ideal RLWE secret key `s` of the protocol
///
/// Bool ciphertext at `index` can be extracted from the coefficient at `index %
/// N` of `index / N`th RLWE ciphertext.
///
/// To extract bool ciphertext at `index` as LWE ciphertext use
/// `self.extract(index)`
pub(super) struct BatchedFheBools<C> {
pub(in super::super) data: Vec<C>,
}
/// Non interactive multi-party specfic encryptor decryptor routines
mod impl_enc_dec {
use crate::{
bool::{evaluator::BoolEncoding, keys::NonInteractiveMultiPartyClientKey},
multi_party::{
multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share,
},
pbs::{sample_extract, PbsInfo, WithShoupRepr},
random::{NewWithSeed, RandomFillUniformInModulus},
rgsw::{rlwe_key_switch, seeded_secret_key_encrypt_rlwe},
utils::TryConvertFrom1,
Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
RowEntity, RowMut,
};
use itertools::Itertools;
use num_traits::{ToPrimitive, Zero};
use super::*;
type Mat = Vec<Vec<u64>>;
// Implement `extract` to extract Bool LWE ciphertext at `index` from
// `BatchedFheBools`
impl<C: MatrixMut<MatElement = u64>> BatchedFheBools<C>
where
C::R: RowEntity + RowMut,
{
pub(crate) fn extract(&self, index: usize) -> C::R {
BoolEvaluator::with_local(|e| {
let ring_size = e.parameters().rlwe_n().0;
let ct_index = index / ring_size;
let coeff_index = index % ring_size;
let mut lwe_out = C::R::zeros(e.parameters().rlwe_n().0 + 1);
sample_extract(
&mut lwe_out,
&self.data[ct_index],
e.pbs_info().modop_rlweq(),
coeff_index,
);
lwe_out
})
}
}
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&(Vec<M::R>, [u8; 32])>
for NonInteractiveBatchedFheBools<M>
where
<M as Matrix>::R: RowMut,
{
/// Derive `NonInteractiveBatchedFheBools` from a vector seeded RLWE
/// ciphertexts (Vec<RLWE>, Seed)
///
/// Unseed the RLWE ciphertexts and store them as vector RLWE
/// ciphertexts in `NonInteractiveBatchedFheBools`
fn from(value: &(Vec<M::R>, [u8; 32])) -> Self {
BoolEvaluator::with_local(|e| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_q = parameters.rlwe_q();
let mut prng = DefaultSecureRng::new_seeded(value.1);
let rlwes = value
.0
.iter()
.map(|partb| {
let mut rlwe = M::zeros(2, ring_size);
// sample A
RandomFillUniformInModulus::random_fill(
&mut prng,
rlwe_q,
rlwe.get_row_mut(0),
);
// Copy over B
rlwe.get_row_mut(1).copy_from_slice(partb.as_ref());
rlwe
})
.collect_vec();
Self { data: rlwes }
})
}
}
impl<K> Encryptor<[bool], NonInteractiveBatchedFheBools<Mat>> for K
where
K: Encryptor<[bool], (Mat, [u8; 32])>,
{
/// Encrypt a vector bool of arbitrary length as vector of unseeded RLWE
/// ciphertexts in `NonInteractiveBatchedFheBools`
fn encrypt(&self, m: &[bool]) -> NonInteractiveBatchedFheBools<Mat> {
NonInteractiveBatchedFheBools::from(&K::encrypt(&self, m))
}
}
impl<K> Encryptor<[bool], (Vec<<Mat as Matrix>::R>, [u8; 32])> for K
where
K: NonInteractiveMultiPartyClientKey,
<Mat as Matrix>::R:
TryConvertFrom1<[K::Element], CiphertextModulus<<Mat as Matrix>::MatElement>>,
{
/// Encrypt a vector of bool of arbitrary length as vector of seeded
/// RLWE ciphertexts and returns (Vec<RLWE>, Seed)
fn encrypt(&self, m: &[bool]) -> (Mat, [u8; 32]) {
BoolEvaluator::with_local(|e| {
DefaultSecureRng::with_local_mut(|rng| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil())
.to_usize()
.unwrap();
let mut seed = <DefaultSecureRng as NewWithSeed>::Seed::default();
rng.fill_bytes(&mut seed);
let mut prng = DefaultSecureRng::new_seeded(seed);
let sk_u = self.sk_u_rlwe();
// encrypt `m` into ceil(len(m)/N) RLWE ciphertexts
let rlwes = (0..rlwe_count)
.map(|index| {
let mut message = vec![<Mat as Matrix>::MatElement::zero(); ring_size];
m[(index * ring_size)..std::cmp::min(m.len(), (index + 1) * ring_size)]
.iter()
.enumerate()
.for_each(|(i, v)| {
if *v {
message[i] = parameters.rlwe_q().true_el()
} else {
message[i] = parameters.rlwe_q().false_el()
}
});
// encrypt message
let mut rlwe_out =
<<Mat as Matrix>::R as RowEntity>::zeros(parameters.rlwe_n().0);
seeded_secret_key_encrypt_rlwe(
&message,
&mut rlwe_out,
&sk_u,
e.pbs_info().modop_rlweq(),
e.pbs_info().nttop_rlweq(),
&mut prng,
rng,
);
rlwe_out
})
.collect_vec();
(rlwes, seed)
})
})
}
}
impl<K> MultiPartyDecryptor<bool, <Mat as Matrix>::R> for K
where
K: NonInteractiveMultiPartyClientKey,
<Mat as Matrix>::R:
TryConvertFrom1<[K::Element], CiphertextModulus<<Mat as Matrix>::MatElement>>,
{
type DecryptionShare = <Mat as Matrix>::MatElement;
fn gen_decryption_share(&self, c: &<Mat as Matrix>::R) -> Self::DecryptionShare {
BoolEvaluator::with_local(|e| {
DefaultSecureRng::with_local_mut(|rng| {
multi_party_decryption_share(
c,
self.sk_rlwe().as_slice(),
e.pbs_info().modop_rlweq(),
rng,
)
})
})
}
fn aggregate_decryption_shares(
&self,
c: &<Mat as Matrix>::R,
shares: &[Self::DecryptionShare],
) -> bool {
BoolEvaluator::with_local(|e| {
let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt(
c,
shares,
e.pbs_info().modop_rlweq(),
);
e.pbs_info().rlwe_q().decode(noisy_m)
})
}
}
impl KeySwitchWithId<Mat> for Mat {
/// Key switch RLWE ciphertext `Self` from user j's RLWE secret u_j
/// to ideal RLWE secret `s` of non-interactive multi-party protocol.
///
/// - user_id: user j's user_id in the protocol
fn key_switch(&self, user_id: usize) -> Mat {
BoolEvaluator::with_local(|e| {
assert!(self.dimension() == (2, e.parameters().rlwe_n().0));
let server_key = BOOL_SERVER_KEY.get().unwrap();
let ksk = server_key.ui_to_s_ksk(user_id);
let decomposer = e.ni_ui_to_s_ks_decomposer().as_ref().unwrap();
// perform key switch
rlwe_key_switch(
self,
ksk.as_ref(),
ksk.shoup_repr(),
decomposer,
e.pbs_info().nttop_rlweq(),
e.pbs_info().modop_rlweq(),
)
})
}
}
impl<C> KeySwitchWithId<BatchedFheBools<C>> for NonInteractiveBatchedFheBools<C>
where
C: KeySwitchWithId<C>,
{
/// Key switch `Self`'s vector of RLWE ciphertexts from user j's RLWE
/// secret u_j to ideal RLWE secret `s` of non-interactive
/// multi-party protocol.
///
/// Returns vector of key switched RLWE ciphertext as `BatchedFheBools`
/// which can then be used to extract individual Bool LWE ciphertexts.
///
/// - user_id: user j's user_id in the protocol
fn key_switch(&self, user_id: usize) -> BatchedFheBools<C> {
let data = self
.data
.iter()
.map(|c| c.key_switch(user_id))
.collect_vec();
BatchedFheBools { data }
}
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use rand::{thread_rng, RngCore};
use crate::{
backend::Modulus,
bool::{
keys::tests::{ideal_sk_rlwe, measure_noise_lwe},
BooleanGates,
},
utils::tests::Stats,
Encoder, Encryptor, KeySwitchWithId, MultiPartyDecryptor,
};
use super::*;
}

+ 738
- 0
src/bool/parameters.rs

@ -0,0 +1,738 @@
use num_traits::{ConstZero, FromPrimitive, PrimInt};
use crate::{
backend::Modulus,
decomposer::{Decomposer, NumInfo},
utils::log2,
};
pub(crate) trait DoubleDecomposerCount {
type Count;
fn a(&self) -> Self::Count;
fn b(&self) -> Self::Count;
}
pub(crate) trait DoubleDecomposerParams {
type Base;
type Count;
fn decomposition_base(&self) -> Self::Base;
fn decomposition_count_a(&self) -> Self::Count;
fn decomposition_count_b(&self) -> Self::Count;
}
pub(crate) trait SingleDecomposerParams {
type Base;
type Count;
// fn new(base: Self::Base, count: Self::Count) -> Self;
fn decomposition_base(&self) -> Self::Base;
fn decomposition_count(&self) -> Self::Count;
}
impl DoubleDecomposerParams
for (
DecompostionLogBase,
// Assume (Decomposition count for A, Decomposition count for B)
(DecompositionCount, DecompositionCount),
)
{
type Base = DecompostionLogBase;
type Count = DecompositionCount;
// fn new(
// base: DecompostionLogBase,
// count_a: DecompositionCount,
// count_b: DecompositionCount,
// ) -> Self {
// (base, (count_a, count_b))
// }
fn decomposition_base(&self) -> Self::Base {
self.0
}
fn decomposition_count_a(&self) -> Self::Count {
self.1 .0
}
fn decomposition_count_b(&self) -> Self::Count {
self.1 .1
}
}
impl SingleDecomposerParams for (DecompostionLogBase, DecompositionCount) {
type Base = DecompostionLogBase;
type Count = DecompositionCount;
// fn new(base: DecompostionLogBase, count: DecompositionCount) -> Self {
// (base, count)
// }
fn decomposition_base(&self) -> Self::Base {
self.0
}
fn decomposition_count(&self) -> Self::Count {
self.1
}
}
#[derive(Clone, PartialEq, Debug)]
pub(crate) enum SecretKeyDistribution {
/// Elements of secret key are sample from Gaussian distribitution with
/// \sigma = 3.19 and \mu = 0.0
ErrorDistribution,
/// Elements of secret key are chosen from the set {1,0,-1} with hamming
/// weight `floor(N/2)` where `N` is the secret dimension.
TernaryDistribution,
}
#[derive(Clone, PartialEq, Debug)]
pub(crate) enum ParameterVariant {
SingleParty,
InteractiveMultiParty,
NonInteractiveMultiParty,
}
#[derive(Clone, PartialEq)]
pub struct BoolParameters<El> {
/// RLWE secret key distribution
rlwe_secret_key_dist: SecretKeyDistribution,
/// LWE secret key distribtuion
lwe_secret_key_dist: SecretKeyDistribution,
/// RLWE ciphertext modulus Q
rlwe_q: CiphertextModulus<El>,
/// LWE ciphertext modulus q (usually referred to as Q_{ks})
lwe_q: CiphertextModulus<El>,
/// Blind rotation modulus. It is the modulus to which we switch before
/// blind rotation.
///
/// Since blind rotation decrypts LWE ciphertext in the exponent of a ring
/// polynomial, which is a ring mod 2N, blind rotation modulus is
/// always <= 2N.
br_q: usize,
/// Ring dimension `N` for 2N^{th} cyclotomic polynomial ring
rlwe_n: PolynomialSize,
/// LWE dimension `n`
lwe_n: LweDimension,
/// LWE key switch decompositon params
lwe_decomposer_params: (DecompostionLogBase, DecompositionCount),
/// Decompostion parameters for RLWE x RGSW.
///
/// We restrict decomposition for RLWE'(-sm) and RLWE'(m) to have same base
/// but can have different decomposition count. We refer to this
/// DoubleDecomposer / RlweDecomposer
///
/// Decomposition count `d_a` (i.e. for SignedDecompose(RLWE_A(m)) x
/// RLWE'(-sm)) and `d_b` (i.e. for SignedDecompose(RLWE_B(m)) x RLWE'(m))
/// are always stored as `(d_a, d_b)`
rlrg_decomposer_params: (
DecompostionLogBase,
(DecompositionCount, DecompositionCount),
),
/// Decomposition parameters for RLWE automorphism
auto_decomposer_params: (DecompostionLogBase, DecompositionCount),
/// Decomposition parameters for RGSW0 x RGSW1
///
/// `0` and `1` indicate that RGSW0 and RGSW1 may not use same decomposition
/// parameters.
///
/// In RGSW0 x RGSW1, decomposition parameters for RGSW1 are required.
/// Hence, the parameters we store are decomposition parameters of RGSW1.
///
/// Like RLWE x RGSW decomposition parameters (1) we restrict to same base
/// but can have different decomposition counts `d_a` and `d_b` and (2)
/// decomposition count `d_a` and `d_b` are always stored as `(d_a, d_b)`
///
/// RGSW0 x RGSW1 are optional because they only necessary to be supplied in
/// multi-party setting.
rgrg_decomposer_params: Option<(
DecompostionLogBase,
(DecompositionCount, DecompositionCount),
)>,
/// Decomposition parameters for non-interactive key switching from u_j to
/// s, hwere u_j is RLWE secret `u` of party `j` and `s` is the ideal RLWE
/// secret key.
///
/// Decomposition parameters for non-interactive key switching are optional
/// and must be supplied only for non-interactive multi-party
non_interactive_ui_to_s_key_switch_decomposer:
Option<(DecompostionLogBase, DecompositionCount)>,
/// Group generator for Z^*_{br_q}
g: usize,
/// Window size parameter for LMKC++ blind rotation
w: usize,
/// Parameter variant
variant: ParameterVariant,
}
impl<El> BoolParameters<El> {
pub(crate) fn rlwe_secret_key_dist(&self) -> &SecretKeyDistribution {
&self.rlwe_secret_key_dist
}
pub(crate) fn lwe_secret_key_dist(&self) -> &SecretKeyDistribution {
&self.lwe_secret_key_dist
}
pub(crate) fn rlwe_q(&self) -> &CiphertextModulus<El> {
&self.rlwe_q
}
pub(crate) fn lwe_q(&self) -> &CiphertextModulus<El> {
&self.lwe_q
}
pub(crate) fn br_q(&self) -> &usize {
&self.br_q
}
pub(crate) fn rlwe_n(&self) -> &PolynomialSize {
&self.rlwe_n
}
pub(crate) fn lwe_n(&self) -> &LweDimension {
&self.lwe_n
}
pub(crate) fn g(&self) -> usize {
self.g
}
pub(crate) fn w(&self) -> usize {
self.w
}
pub(crate) fn rlwe_by_rgsw_decomposition_params(
&self,
) -> &(
DecompostionLogBase,
(DecompositionCount, DecompositionCount),
) {
&self.rlrg_decomposer_params
}
pub(crate) fn rgsw_by_rgsw_decomposition_params(
&self,
) -> (
DecompostionLogBase,
(DecompositionCount, DecompositionCount),
) {
self.rgrg_decomposer_params.expect(&format!(
"Parameter variant {:?} does not support RGSWxRGSW",
self.variant
))
}
pub(crate) fn rlwe_rgsw_decomposition_base(&self) -> DecompostionLogBase {
self.rlrg_decomposer_params.0
}
pub(crate) fn rlwe_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) {
self.rlrg_decomposer_params.1
}
pub(crate) fn rgsw_rgsw_decomposition_count(&self) -> (DecompositionCount, DecompositionCount) {
let params = self.rgrg_decomposer_params.expect(&format!(
"Parameter variant {:?} does not support RGSW x RGSW",
self.variant
));
params.1
}
pub(crate) fn auto_decomposition_param(&self) -> &(DecompostionLogBase, DecompositionCount) {
&self.auto_decomposer_params
}
pub(crate) fn auto_decomposition_base(&self) -> DecompostionLogBase {
self.auto_decomposer_params.decomposition_base()
}
pub(crate) fn auto_decomposition_count(&self) -> DecompositionCount {
self.auto_decomposer_params.decomposition_count()
}
pub(crate) fn lwe_decomposition_base(&self) -> DecompostionLogBase {
self.lwe_decomposer_params.decomposition_base()
}
pub(crate) fn lwe_decomposition_count(&self) -> DecompositionCount {
self.lwe_decomposer_params.decomposition_count()
}
pub(crate) fn non_interactive_ui_to_s_key_switch_decomposition_count(
&self,
) -> DecompositionCount {
let params = self
.non_interactive_ui_to_s_key_switch_decomposer
.expect(&format!(
"Parameter variant {:?} does not support non-interactive",
self.variant
));
params.decomposition_count()
}
pub(crate) fn rgsw_rgsw_decomposer<D: Decomposer<Element = El>>(&self) -> (D, D)
where
El: Copy,
{
let params = self.rgrg_decomposer_params.expect(&format!(
"Parameter variant {:?} does not support RGSW x RGSW",
self.variant
));
(
// A
D::new(
self.rlwe_q.0,
params.decomposition_base().0,
params.decomposition_count_a().0,
),
// B
D::new(
self.rlwe_q.0,
params.decomposition_base().0,
params.decomposition_count_b().0,
),
)
}
pub(crate) fn auto_decomposer<D: Decomposer<Element = El>>(&self) -> D
where
El: Copy,
{
D::new(
self.rlwe_q.0,
self.auto_decomposer_params.decomposition_base().0,
self.auto_decomposer_params.decomposition_count().0,
)
}
pub(crate) fn lwe_decomposer<D: Decomposer<Element = El>>(&self) -> D
where
El: Copy,
{
D::new(
self.lwe_q.0,
self.lwe_decomposer_params.decomposition_base().0,
self.lwe_decomposer_params.decomposition_count().0,
)
}
pub(crate) fn rlwe_rgsw_decomposer<D: Decomposer<Element = El>>(&self) -> (D, D)
where
El: Copy,
{
(
// A
D::new(
self.rlwe_q.0,
self.rlrg_decomposer_params.decomposition_base().0,
self.rlrg_decomposer_params.decomposition_count_a().0,
),
// B
D::new(
self.rlwe_q.0,
self.rlrg_decomposer_params.decomposition_base().0,
self.rlrg_decomposer_params.decomposition_count_b().0,
),
)
}
pub(crate) fn non_interactive_ui_to_s_key_switch_decomposer<D: Decomposer<Element = El>>(
&self,
) -> D
where
El: Copy,
{
let params = self
.non_interactive_ui_to_s_key_switch_decomposer
.expect(&format!(
"Parameter variant {:?} does not support non-interactive",
self.variant
));
D::new(
self.rlwe_q.0,
params.decomposition_base().0,
params.decomposition_count().0,
)
}
/// Returns dlogs of `g` for which auto keys are required as
/// per the parameter. Given that autos are required for [-g, g, g^2, ...,
/// g^w] function returns the following [0, 1, 2, ..., w] where `w` is
/// the window size. Note that although g^0 = 1, we use 0 for -g.
pub(crate) fn auto_element_dlogs(&self) -> Vec<usize> {
let mut els = vec![0];
(1..self.w + 1).into_iter().for_each(|e| {
els.push(e);
});
els
}
pub(crate) fn variant(&self) -> &ParameterVariant {
&self.variant
}
}
#[derive(Clone, Copy, PartialEq)]
pub struct DecompostionLogBase(pub(crate) usize);
impl AsRef<usize> for DecompostionLogBase {
fn as_ref(&self) -> &usize {
&self.0
}
}
#[derive(Clone, Copy, PartialEq)]
pub struct DecompositionCount(pub(crate) usize);
impl AsRef<usize> for DecompositionCount {
fn as_ref(&self) -> &usize {
&self.0
}
}
#[derive(Clone, Copy, PartialEq)]
pub(crate) struct LweDimension(pub(crate) usize);
#[derive(Clone, Copy, PartialEq)]
pub(crate) struct PolynomialSize(pub(crate) usize);
#[derive(Clone, Copy, PartialEq, Debug)]
/// T equals modulus when modulus is non-native. Otherwise T equals 0. bool is
/// true when modulus is native, false otherwise.
pub struct CiphertextModulus<T>(T, bool);
impl<T: ConstZero> CiphertextModulus<T> {
const fn new_native() -> Self {
// T::zero is stored only for convenience. It has no use when modulus
// is native. That is, either u128,u64,u32,u16
Self(T::ZERO, true)
}
const fn new_non_native(q: T) -> Self {
Self(q, false)
}
}
impl<T> CiphertextModulus<T>
where
T: PrimInt + NumInfo,
{
fn _bits() -> usize {
T::BITS as usize
}
fn _native(&self) -> bool {
self.1
}
fn _half_q(&self) -> T {
if self._native() {
T::one() << (Self::_bits() - 1)
} else {
self.0 >> 1
}
}
fn _q(&self) -> Option<T> {
if self._native() {
None
} else {
Some(self.0)
}
}
}
impl<T> Modulus for CiphertextModulus<T>
where
T: PrimInt + FromPrimitive + NumInfo,
{
type Element = T;
fn is_native(&self) -> bool {
self._native()
}
fn largest_unsigned_value(&self) -> Self::Element {
if self._native() {
T::max_value()
} else {
self.0 - T::one()
}
}
fn neg_one(&self) -> Self::Element {
if self._native() {
T::max_value()
} else {
self.0 - T::one()
}
}
// fn signed_max(&self) -> Self::Element {}
// fn signed_min(&self) -> Self::Element {}
fn smallest_unsigned_value(&self) -> Self::Element {
T::zero()
}
fn map_element_to_i64(&self, v: &Self::Element) -> i64 {
assert!(*v <= self.largest_unsigned_value());
if *v > self._half_q() {
-((self.largest_unsigned_value() - *v) + T::one())
.to_i64()
.unwrap()
} else {
v.to_i64().unwrap()
}
}
fn map_element_from_f64(&self, v: f64) -> Self::Element {
let v = v.round();
let v_el = T::from_f64(v.abs()).unwrap();
assert!(v_el <= self.largest_unsigned_value());
if v < 0.0 {
self.largest_unsigned_value() - v_el + T::one()
} else {
v_el
}
}
fn map_element_from_i64(&self, v: i64) -> Self::Element {
let v_el = T::from_i64(v.abs()).unwrap();
assert!(v_el <= self.largest_unsigned_value());
if v < 0 {
self.largest_unsigned_value() - v_el + T::one()
} else {
v_el
}
}
fn q(&self) -> Option<Self::Element> {
self._q()
}
fn q_as_f64(&self) -> Option<f64> {
if self._native() {
Some(T::max_value().to_f64().unwrap() + 1.0)
} else {
self.0.to_f64()
}
}
fn log_q(&self) -> usize {
if self.is_native() {
Self::_bits()
} else {
log2(&self.q().unwrap())
}
}
}
pub(crate) const I_2P_LB_SR: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 15),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(580),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(12)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(7),
(DecompositionCount(6), DecompositionCount(5)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5,
w: 10,
variant: ParameterVariant::InteractiveMultiParty,
};
pub(crate) const I_4P: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(620),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(6),
(DecompositionCount(7), DecompositionCount(6)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5,
w: 10,
variant: ParameterVariant::InteractiveMultiParty,
};
pub(crate) const I_8P: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 17),
br_q: 1 << 12,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(660),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(14)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(5),
(DecompositionCount(9), DecompositionCount(8)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5,
w: 10,
variant: ParameterVariant::InteractiveMultiParty,
};
pub(crate) const NI_2P: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::ErrorDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 12,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(520),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(4),
(DecompositionCount(10), DecompositionCount(9)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: Some((
DecompostionLogBase(1),
DecompositionCount(50),
)),
g: 5,
w: 10,
variant: ParameterVariant::NonInteractiveMultiParty,
};
pub(crate) const NI_4P_HB_FR: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(620),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(3),
(DecompositionCount(13), DecompositionCount(12)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: Some((
DecompostionLogBase(1),
DecompositionCount(50),
)),
g: 5,
w: 10,
variant: ParameterVariant::NonInteractiveMultiParty,
};
pub(crate) const NI_4P_LB_SR: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 12,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(620),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(13)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(4),
(DecompositionCount(10), DecompositionCount(9)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: Some((
DecompostionLogBase(1),
DecompositionCount(50),
)),
g: 5,
w: 10,
variant: ParameterVariant::NonInteractiveMultiParty,
};
pub(crate) const NI_8P: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
rlwe_q: CiphertextModulus::new_non_native(18014398509404161),
lwe_q: CiphertextModulus::new_non_native(1 << 17),
br_q: 1 << 12,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(660),
lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(14)),
rlrg_decomposer_params: (
DecompostionLogBase(17),
(DecompositionCount(1), DecompositionCount(1)),
),
rgrg_decomposer_params: Some((
DecompostionLogBase(2),
(DecompositionCount(20), DecompositionCount(18)),
)),
auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)),
non_interactive_ui_to_s_key_switch_decomposer: Some((
DecompostionLogBase(1),
DecompositionCount(50),
)),
g: 5,
w: 10,
variant: ParameterVariant::NonInteractiveMultiParty,
};
#[cfg(test)]
pub(crate) const SP_TEST_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_secret_key_dist: SecretKeyDistribution::TernaryDistribution,
lwe_secret_key_dist: SecretKeyDistribution::ErrorDistribution,
rlwe_q: CiphertextModulus::new_non_native(268369921u64),
lwe_q: CiphertextModulus::new_non_native(1 << 16),
br_q: 1 << 9,
rlwe_n: PolynomialSize(1 << 9),
lwe_n: LweDimension(100),
lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(4)),
rlrg_decomposer_params: (
DecompostionLogBase(7),
(DecompositionCount(4), DecompositionCount(4)),
),
rgrg_decomposer_params: None,
auto_decomposer_params: (DecompostionLogBase(7), DecompositionCount(4)),
non_interactive_ui_to_s_key_switch_decomposer: None,
g: 5,
w: 5,
variant: ParameterVariant::SingleParty,
};
// #[cfg(test)]
// mod tests {
// #[test]
// fn find_prime() {
// let bits = 60;
// let ring_size = 1 << 11;
// let prime = crate::utils::generate_prime(bits, ring_size * 2, 1 <<
// bits).unwrap(); dbg!(prime);
// }
// }

+ 1020
- 0
src/bool/print_noise.rs
File diff suppressed because it is too large
View File


+ 295
- 88
src/decomposer.rs

@ -1,32 +1,114 @@
use itertools::Itertools;
use num_traits::{AsPrimitive, One, PrimInt, ToPrimitive, WrappingSub, Zero};
use std::{fmt::Debug, marker::PhantomData, ops::Rem};
use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
use std::fmt::{Debug, Display};
use crate::backend::{ArithmeticOps, ModularOpsU64};
use crate::{
backend::ArithmeticOps,
parameters::{
DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams,
},
utils::log2,
};
pub fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
let ignored_limbs = d_ideal - d;
(ignored_limbs..ignored_limbs + d)
fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
assert!(logq >= (logb * d));
let ignored_bits = logq - (logb * d);
(0..d)
.into_iter() .into_iter()
.map(|i| T::one() << (logb * i))
.map(|i| T::one() << (logb * i + ignored_bits))
.collect_vec() .collect_vec()
} }
pub trait RlweDecomposer {
type Element;
type D: Decomposer<Element = Self::Element>;
/// Decomposer for RLWE Part A
fn a(&self) -> &Self::D;
/// Decomposer for RLWE Part B
fn b(&self) -> &Self::D;
}
impl<D> RlweDecomposer for (D, D)
where
D: Decomposer,
{
type D = D;
type Element = D::Element;
fn a(&self) -> &Self::D {
&self.0
}
fn b(&self) -> &Self::D {
&self.1
}
}
impl<D> DoubleDecomposerParams for D
where
D: RlweDecomposer,
{
type Base = DecompostionLogBase;
type Count = DecompositionCount;
fn decomposition_base(&self) -> Self::Base {
assert!(
Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b())
);
Decomposer::decomposition_base(self.a())
}
fn decomposition_count_a(&self) -> Self::Count {
Decomposer::decomposition_count(self.a())
}
fn decomposition_count_b(&self) -> Self::Count {
Decomposer::decomposition_count(self.b())
}
}
impl<D> SingleDecomposerParams for D
where
D: Decomposer,
{
type Base = DecompostionLogBase;
type Count = DecompositionCount;
fn decomposition_base(&self) -> Self::Base {
Decomposer::decomposition_base(self)
}
fn decomposition_count(&self) -> Self::Count {
Decomposer::decomposition_count(self)
}
}
pub trait Decomposer { pub trait Decomposer {
type Element; type Element;
//FIXME(Jay): there's no reason why it returns a vec instead of an iterator
fn decompose(&self, v: &Self::Element) -> Vec<Self::Element>;
fn d(&self) -> usize;
type Iter: Iterator<Item = Self::Element>;
fn new(q: Self::Element, logb: usize, d: usize) -> Self;
fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
fn decomposition_count(&self) -> DecompositionCount;
fn decomposition_base(&self) -> DecompostionLogBase;
fn gadget_vector(&self) -> Vec<Self::Element>;
} }
pub struct DefaultDecomposer<T> { pub struct DefaultDecomposer<T> {
/// Ciphertext modulus
q: T, q: T,
/// Log of ciphertext modulus
logq: usize, logq: usize,
/// Log of base B
logb: usize, logb: usize,
/// base B
b: T,
/// (B - 1). To simulate (% B) as &(B-1), that is extract least significant
/// logb bits
b_mask: T,
/// B/2
bby2: T,
/// Decomposition count
d: usize, d: usize,
/// No. of bits to ignore in rounding
ignore_bits: usize, ignore_bits: usize,
ignore_limbs: usize,
} }
pub trait NumInfo { pub trait NumInfo {
@ -44,121 +126,246 @@ impl NumInfo for u128 {
} }
impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> { impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then BITS - leading zeros outputs logq + 1.
let logq = if q & (q - T::one()) == T::zero() {
(T::BITS - q.leading_zeros() - 1) as usize
} else {
(T::BITS - q.leading_zeros()) as usize
};
fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T
where
Op: ArithmeticOps<Element = T>,
{
let mut value = T::zero();
let gadget_vector = gadget_vector(self.logq, self.logb, self.d);
assert!(limbs.len() == gadget_vector.len());
izip!(limbs.iter(), gadget_vector.iter())
.for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta)));
value
}
}
impl<
T: PrimInt
+ ToPrimitive
+ FromPrimitive
+ WrappingSub
+ WrappingAdd
+ NumInfo
+ From<bool>
+ Display
+ Debug,
> Decomposer for DefaultDecomposer<T>
{
type Element = T;
type Iter = DecomposerIter<T>;
fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
// if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
let logq = log2(&q);
assert!(
logq >= (logb * d),
"Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
);
let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap();
let ignore_limbs = (d_ideal - d);
let ignore_bits = (d_ideal - d) * logb;
let ignore_bits = logq - (logb * d);
DefaultDecomposer { DefaultDecomposer {
q, q,
logq, logq,
logb, logb,
b: T::one() << logb,
b_mask: (T::one() << logb) - T::one(),
bby2: T::one() << (logb - 1),
d, d,
ignore_bits, ignore_bits,
ignore_limbs,
} }
} }
fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T
where
Op: ArithmeticOps<Element = T>,
{
let mut value = T::zero();
for i in self.ignore_limbs..self.ignore_limbs + self.d {
value = modq_op.add(
&value,
&(modq_op.mul(&limbs[i], &(T::one() << (self.logb * i)))),
)
fn decompose_to_vec(&self, value: &T) -> Vec<T> {
let q = self.q;
let logb = self.logb;
let b = T::one() << logb;
let full_mask = b - T::one();
let bby2 = b >> 1;
let mut value = *value;
if value >= (q >> 1) {
value = !(q - value) + T::one()
} }
value
}
}
value = round_value(value, self.ignore_bits);
let mut out = Vec::with_capacity(self.d);
for _ in 0..(self.d) {
let k_i = value & full_mask;
impl<T: PrimInt + WrappingSub + Debug> Decomposer for DefaultDecomposer<T> {
type Element = T;
fn decompose(&self, value: &T) -> Vec<T> {
let value = round_value(*value, self.ignore_bits);
value = (value - k_i) >> logb;
let q = self.q;
let logb = self.logb;
// let b = T::one() << logb; // base
let b_by2 = T::one() << (logb - 1);
// let neg_b_by2_modq = q - b_by2;
let full_mask = (T::one() << logb) - T::one();
// let half_mask = b_by2 - T::one();
let mut carry = T::zero();
let mut out = Vec::<T>::with_capacity(self.d);
for i in 0..self.d {
let mut limb = ((value >> (logb * i)) & full_mask) + carry;
carry = limb & b_by2;
limb = (q + limb) - (carry << 1);
if limb > q {
limb = limb - q;
if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) {
out.push(q - (b - k_i));
value = value + T::one();
} else {
out.push(k_i);
} }
out.push(limb);
carry = carry >> (logb - 1);
} }
return out; return out;
} }
fn d(&self) -> usize {
self.d
fn decomposition_count(&self) -> DecompositionCount {
DecompositionCount(self.d)
}
fn decomposition_base(&self) -> DecompostionLogBase {
DecompostionLogBase(self.logb)
}
fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
let mut value = *value;
if value >= (self.q >> 1) {
value = !(self.q - value) + T::one()
}
value = round_value(value, self.ignore_bits);
DecomposerIter {
value,
q: self.q,
logq: self.logq,
logb: self.logb,
b: self.b,
bby2: self.bby2,
b_mask: self.b_mask,
steps_left: self.d,
}
}
fn gadget_vector(&self) -> Vec<T> {
return gadget_vector(self.logq, self.logb, self.d);
} }
} }
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
impl<T: PrimInt> DefaultDecomposer<T> {}
pub struct DecomposerIter<T> {
/// Value to decompose
value: T,
steps_left: usize,
/// (1 << logb) - 1 (for % (1<<logb); i.e. to extract least signiciant logb
/// bits)
b_mask: T,
logb: usize,
// b/2 = 1 << (logb-1)
bby2: T,
/// Ciphertext modulus
q: T,
/// Log of ciphertext modulus
logq: usize,
/// b = 1 << logb
b: T,
}
impl<T: PrimInt + From<bool> + WrappingSub + Display> Iterator for DecomposerIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.steps_left != 0 {
self.steps_left -= 1;
let k_i = self.value & self.b_mask;
self.value = (self.value - k_i) >> self.logb;
// if k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
// T::one()) == T::one())) { self.value = self.value
// + T::one(); Some(self.q + k_i - self.b)
// } else {
// Some(k_i)
// }
// Following is without branching impl of the commented version above. It
// happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other
// parameters as well but I haven't tested) by roughly 15ms.
// Suprisingly the improvement does not show up when I benchmark
// `decomposer_iter` in isolation. Putting this remark here as a
// future task to investiage (TODO).
let carry_bool =
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one()));
let carry = <T as From<bool>>::from(carry_bool);
let neg_carry = (T::zero().wrapping_sub(&carry));
self.value = self.value + carry;
Some((neg_carry & self.q) + k_i - (carry << self.logb))
// Some(
// (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
// - (carry << self.logb), )
// Some(k_i)
} else {
None
}
}
}
fn round_value<T: PrimInt + WrappingAdd>(value: T, ignore_bits: usize) -> T {
if ignore_bits == 0 { if ignore_bits == 0 {
return value; return value;
} }
let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1);
(value >> ignore_bits) + ignored_msb
(value >> ignore_bits).wrapping_add(&ignored_msb)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use itertools::Itertools;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use crate::{backend::ModularOpsU64, decomposer::round_value, utils::generate_prime};
use crate::{
backend::{ModInit, ModularOpsU64},
decomposer::round_value,
utils::generate_prime,
};
use super::{Decomposer, DefaultDecomposer}; use super::{Decomposer, DefaultDecomposer};
#[test] #[test]
fn decomposition_works() { fn decomposition_works() {
let logq = 50;
let logb = 5;
let d = 10;
// q is prime of bits logq and i is true, other q = 1<<logq
for i in [true, false] {
let q = if i {
generate_prime(logq, 1 << 4, 1u64 << logq).unwrap()
} else {
1u64 << 50
};
let decomposer = DefaultDecomposer::new(q, logb, d);
let modq_op = ModularOpsU64::new(q);
for _ in 0..100 {
let value = 1000000;
let limbs = decomposer.decompose(&value);
let value_back = decomposer.recompose(&limbs, &modq_op);
let rounded_value = round_value(value, decomposer.ignore_bits);
assert_eq!(
rounded_value, value_back,
"Expected {rounded_value} got {value_back} for q={q}"
);
let ring_size = 1 << 11;
let mut rng = thread_rng();
for logq in [37, 55] {
let logb = 11;
let d = 3;
// let mut stats = vec![Stats::new(); d];
for i in [true, false] {
let q = if i {
generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
} else {
1u64 << logq
};
let decomposer = DefaultDecomposer::new(q, logb, d);
let modq_op = ModularOpsU64::new(q);
for _ in 0..1000000 {
let value = rng.gen_range(0..q);
let limbs = decomposer.decompose_to_vec(&value);
let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
assert_eq!(limbs, limbs_from_iter);
let value_back = round_value(
decomposer.recompose(&limbs, &modq_op),
decomposer.ignore_bits,
);
let rounded_value = round_value(value, decomposer.ignore_bits);
assert!((rounded_value as i64 - value_back as i64).abs() <= 1,);
// izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| {
// s.add_more(&vec![q.map_element_to_i64(l)]);
// });
}
} }
// stats.iter().enumerate().for_each(|(index, s)| {
// println!(
// "Limb {index} - Mean: {}, Std: {}",
// s.mean(),
// s.std_dev().abs().log2()
// );
// });
} }
} }
} }

+ 121
- 11
src/lib.rs

@ -1,18 +1,29 @@
use itertools::{izip, Itertools};
use num::UnsignedInteger;
use num_traits::{abs, Zero};
use rand::CryptoRng;
use random::{RandomGaussianDist, RandomUniformDist};
use utils::TryConvertFrom;
use num_traits::Zero;
mod backend; mod backend;
mod bool;
mod decomposer; mod decomposer;
mod lwe; mod lwe;
mod multi_party;
mod ntt; mod ntt;
mod num;
mod pbs;
mod random; mod random;
mod rgsw; mod rgsw;
#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))]
mod shortint;
mod utils; mod utils;
pub use backend::{
ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps,
};
pub use bool::*;
pub use ntt::{Ntt, NttBackendU64, NttInit};
#[cfg(any(feature = "interactive_mp", feature = "non_interactive_mp"))]
pub use shortint::{div_zero_error_flag, reset_error_flags, FheUint8};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
pub trait Matrix: AsRef<[Self::R]> { pub trait Matrix: AsRef<[Self::R]> {
type MatElement; type MatElement;
type R: Row<Element = Self::MatElement>; type R: Row<Element = Self::MatElement>;
@ -34,6 +45,13 @@ pub trait Matrix: AsRef<[Self::R]> {
fn get(&self, row_idx: usize, column_idx: usize) -> &Self::MatElement { fn get(&self, row_idx: usize, column_idx: usize) -> &Self::MatElement {
&self.as_ref()[row_idx].as_ref()[column_idx] &self.as_ref()[row_idx].as_ref()[column_idx]
} }
fn split_at_row(&self, idx: usize) -> (&[<Self as Matrix>::R], &[<Self as Matrix>::R]) {
self.as_ref().split_at(idx)
}
/// Does the matrix fit sub-matrix of dimension row x col
fn fits(&self, row: usize, col: usize) -> bool;
} }
pub trait MatrixMut: Matrix + AsMut<[<Self as Matrix>::R]> pub trait MatrixMut: Matrix + AsMut<[<Self as Matrix>::R]>
@ -52,7 +70,7 @@ where
self.as_mut()[row_idx].as_mut()[column_idx] = val; self.as_mut()[row_idx].as_mut()[column_idx] = val;
} }
fn split_at_row(
fn split_at_row_mut(
&mut self, &mut self,
idx: usize, idx: usize,
) -> (&mut [<Self as Matrix>::R], &mut [<Self as Matrix>::R]) { ) -> (&mut [<Self as Matrix>::R], &mut [<Self as Matrix>::R]) {
@ -72,9 +90,8 @@ pub trait Row: AsRef<[Self::Element]> {
pub trait RowMut: Row + AsMut<[<Self as Row>::Element]> {} pub trait RowMut: Row + AsMut<[<Self as Row>::Element]> {}
trait Secret {
type Element;
fn values(&self) -> &[Self::Element];
pub trait RowEntity: Row {
fn zeros(col: usize) -> Self;
} }
impl<T> Matrix for Vec<Vec<T>> { impl<T> Matrix for Vec<Vec<T>> {
@ -84,9 +101,40 @@ impl Matrix for Vec> {
fn dimension(&self) -> (usize, usize) { fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len()) (self.len(), self[0].len())
} }
fn fits(&self, row: usize, col: usize) -> bool {
self.len() >= row && self[0].len() >= col
}
}
impl<T> Matrix for &[Vec<T>] {
type MatElement = T;
type R = Vec<T>;
fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
fn fits(&self, row: usize, col: usize) -> bool {
self.len() >= row && self[0].len() >= col
}
}
impl<T> Matrix for &mut [Vec<T>] {
type MatElement = T;
type R = Vec<T>;
fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
fn fits(&self, row: usize, col: usize) -> bool {
self.len() >= row && self[0].len() >= col
}
} }
impl<T> MatrixMut for Vec<Vec<T>> {} impl<T> MatrixMut for Vec<Vec<T>> {}
impl<T> MatrixMut for &mut [Vec<T>] {}
impl<T: Zero + Clone> MatrixEntity for Vec<Vec<T>> { impl<T: Zero + Clone> MatrixEntity for Vec<Vec<T>> {
fn zeros(row: usize, col: usize) -> Self { fn zeros(row: usize, col: usize) -> Self {
@ -98,4 +146,66 @@ impl Row for Vec {
type Element = T; type Element = T;
} }
impl<T> Row for [T] {
type Element = T;
}
impl<T> RowMut for Vec<T> {} impl<T> RowMut for Vec<T> {}
impl<T: Zero + Clone> RowEntity for Vec<T> {
fn zeros(col: usize) -> Self {
vec![T::zero(); col]
}
}
pub trait Encryptor<M: ?Sized, C> {
fn encrypt(&self, m: &M) -> C;
}
pub trait Decryptor<M, C> {
fn decrypt(&self, c: &C) -> M;
}
pub trait MultiPartyDecryptor<M, C> {
type DecryptionShare;
fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare;
fn aggregate_decryption_shares(&self, c: &C, shares: &[Self::DecryptionShare]) -> M;
}
pub trait KeySwitchWithId<C> {
fn key_switch(&self, user_id: usize) -> C;
}
pub trait SampleExtractor<R> {
/// Extract ciphertext at `index`
fn extract_at(&self, index: usize) -> R;
/// Extract all ciphertexts
fn extract_all(&self) -> Vec<R>;
/// Extract first `how_many` ciphertexts
fn extract_many(&self, how_many: usize) -> Vec<R>;
}
trait Encoder<F, T> {
fn encode(&self, v: F) -> T;
}
trait SizeInBitsWithLogModulus {
/// Returns size of `Self` containing several elements mod Q where
/// 2^{log_modulus-1} < Q <= `2^log_modulus`
fn size(&self, log_modulus: usize) -> usize;
}
impl SizeInBitsWithLogModulus for Vec<Vec<u64>> {
fn size(&self, log_modulus: usize) -> usize {
let mut total = 0;
self.iter().for_each(|r| total += log_modulus * r.len());
total
}
}
impl SizeInBitsWithLogModulus for Vec<u64> {
fn size(&self, log_modulus: usize) -> usize {
self.len() * log_modulus
}
}

+ 221
- 174
src/lwe.rs

@ -1,220 +1,250 @@
use std::fmt::Debug; use std::fmt::Debug;
use itertools::{izip, Itertools};
use num_traits::{abs, Zero};
use itertools::izip;
use num_traits::Zero;
use crate::{ use crate::{
backend::{ArithmeticOps, VectorOps},
backend::{ArithmeticOps, GetModulus, VectorOps},
decomposer::Decomposer, decomposer::Decomposer,
lwe,
num::UnsignedInteger,
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG},
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowMut, Secret,
random::{RandomFillUniformInModulus, RandomGaussianElementInModulus},
utils::TryConvertFrom1,
Matrix, Row, RowEntity, RowMut,
}; };
trait LweKeySwitchParameters {
fn n_in(&self) -> usize;
fn n_out(&self) -> usize;
fn d_ks(&self) -> usize;
}
trait LweCiphertext<M: Matrix> {}
struct LweSecret {
values: Vec<i32>,
}
impl Secret for LweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl LweSecret {
fn random(hw: usize, n: usize) -> LweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
LweSecret { values: out }
})
}
}
fn lwe_key_switch<
pub(crate) fn lwe_key_switch<
M: Matrix, M: Matrix,
Mmut: MatrixMut<MatElement = M::MatElement> + MatrixEntity,
Ro: AsMut<[M::MatElement]> + AsRef<[M::MatElement]>,
Op: VectorOps<Element = M::MatElement> + ArithmeticOps<Element = M::MatElement>, Op: VectorOps<Element = M::MatElement> + ArithmeticOps<Element = M::MatElement>,
D: Decomposer<Element = M::MatElement>, D: Decomposer<Element = M::MatElement>,
>( >(
lwe_out: &mut Mmut,
lwe_in: &M,
lwe_out: &mut Ro,
lwe_in: &Ro,
lwe_ksk: &M, lwe_ksk: &M,
operator: &Op, operator: &Op,
decomposer: &D, decomposer: &D,
) where
<Mmut as Matrix>::R: RowMut,
{
assert!(lwe_ksk.dimension().0 == ((lwe_in.dimension().1 - 1) * decomposer.d()));
assert!(lwe_out.dimension() == (1, lwe_ksk.dimension().1));
let mut scratch_space = Mmut::zeros(1, lwe_out.dimension().1);
) {
assert!(
lwe_ksk.dimension().0 == ((lwe_in.as_ref().len() - 1) * decomposer.decomposition_count().0)
);
assert!(lwe_out.as_ref().len() == lwe_ksk.dimension().1);
let lwe_in_a_decomposed = lwe_in let lwe_in_a_decomposed = lwe_in
.get_row(0)
.as_ref()
.iter()
.skip(1) .skip(1)
.flat_map(|ai| decomposer.decompose(ai));
.flat_map(|ai| decomposer.decompose_iter(ai));
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
operator.elwise_scalar_mul(scratch_space.get_row_mut(0), beta_ij_lwe.as_ref(), &ai_j);
operator.elwise_add_mut(lwe_out.get_row_mut(0), scratch_space.get_row_slice(0))
// let now = std::time::Instant::now();
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
// println!("Time elwise_fma_scalar_mut: {:?}", now.elapsed());
}); });
let out_b = operator.add(lwe_out.get(0, 0), lwe_in.get(0, 0));
lwe_out.set(0, 0, out_b);
let out_b = operator.add(&lwe_out.as_ref()[0], &lwe_in.as_ref()[0]);
lwe_out.as_mut()[0] = out_b;
} }
fn lwe_ksk_keygen<
Mmut: MatrixMut,
S: Secret,
Op: VectorOps<Element = Mmut::MatElement> + ArithmeticOps<Element = Mmut::MatElement>,
R: RandomGaussianDist<Mmut::MatElement, Parameters = Mmut::MatElement>
+ RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
pub(crate) fn seeded_lwe_ksk_keygen<
Ro: RowMut + RowEntity,
S,
Op: VectorOps<Element = Ro::Element>
+ ArithmeticOps<Element = Ro::Element>
+ GetModulus<Element = Ro::Element>,
R: RandomGaussianElementInModulus<Ro::Element, Op::M>,
PR: RandomFillUniformInModulus<[Ro::Element], Op::M>,
>( >(
lwe_sk_in: &S,
lwe_sk_out: &S,
ksk_out: &mut Mmut,
gadget: &[Mmut::MatElement],
from_lwe_sk: &[S],
to_lwe_sk: &[S],
gadget: &[Ro::Element],
operator: &Op, operator: &Op,
p_rng: &mut PR,
rng: &mut R, rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut::MatElement: Zero + Debug,
) -> Ro
where
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero + Debug,
{ {
assert!(
ksk_out.dimension()
== (
lwe_sk_in.values().len() * gadget.len(),
lwe_sk_out.values().len() + 1,
)
);
let mut ksk_out = Ro::zeros(from_lwe_sk.len() * gadget.len());
let d = gadget.len(); let d = gadget.len();
let modulus = VectorOps::modulus(operator);
let mut neg_sk_in_m = Mmut::try_convert_from(lwe_sk_in.values(), &modulus);
operator.elwise_neg_mut(neg_sk_in_m.get_row_mut(0));
let sk_out_m = Mmut::try_convert_from(lwe_sk_out.values(), &modulus);
izip!(
neg_sk_in_m.get_row(0),
ksk_out.iter_rows_mut().chunks(d).into_iter()
)
.for_each(|(neg_sk_in_si, d_ks_lwes)| {
izip!(gadget.iter(), d_ks_lwes.into_iter()).for_each(|(f, lwe)| {
// sample `a`
RandomUniformDist::random_fill(rng, &modulus, &mut lwe.as_mut()[1..]);
// a * z
let mut az = Mmut::MatElement::zero();
izip!(lwe.as_ref()[1..].iter(), sk_out_m.get_row(0)).for_each(|(ai, si)| {
let ai_si = operator.mul(ai, si);
az = operator.add(&az, &ai_si);
});
// a*z + (-s_i)*\beta^j + e
let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si));
let mut e = Mmut::MatElement::zero();
RandomGaussianDist::random_fill(rng, &modulus, &mut e);
b = operator.add(&b, &e);
lwe.as_mut()[0] = b;
let modulus = operator.modulus();
let mut neg_sk_in_m = Ro::try_convert_from(from_lwe_sk, modulus);
operator.elwise_neg_mut(neg_sk_in_m.as_mut());
let sk_out_m = Ro::try_convert_from(to_lwe_sk, modulus);
let mut scratch = Ro::zeros(to_lwe_sk.len());
izip!(neg_sk_in_m.as_ref(), ksk_out.as_mut().chunks_mut(d)).for_each(
|(neg_sk_in_si, d_lwes_partb)| {
izip!(gadget.iter(), d_lwes_partb.into_iter()).for_each(|(beta, lwe_b)| {
// sample `a`
RandomFillUniformInModulus::random_fill(p_rng, &modulus, scratch.as_mut());
// a * z
let mut az = Ro::Element::zero();
izip!(scratch.as_ref().iter(), sk_out_m.as_ref()).for_each(|(ai, si)| {
let ai_si = operator.mul(ai, si);
az = operator.add(&az, &ai_si);
});
// a*z + (-s_i)*\beta^j + e
let mut b = operator.add(&az, &operator.mul(beta, neg_sk_in_si));
let e = RandomGaussianElementInModulus::random(rng, &modulus);
b = operator.add(&b, &e);
*lwe_b = b;
})
},
);
// dbg!(&lwe.as_mut(), &f);
})
});
ksk_out
} }
/// Encrypts encoded message m as LWE ciphertext /// Encrypts encoded message m as LWE ciphertext
fn encrypt_lwe<
Mmut: MatrixMut + MatrixEntity,
R: RandomGaussianDist<Mmut::MatElement, Parameters = Mmut::MatElement>
+ RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
S: Secret,
Op: ArithmeticOps<Element = Mmut::MatElement>,
pub(crate) fn encrypt_lwe<
Ro: RowMut + RowEntity,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
R: RandomGaussianElementInModulus<Ro::Element, Op::M>
+ RandomFillUniformInModulus<[Ro::Element], Op::M>,
S,
>( >(
lwe_out: &mut Mmut,
m: Mmut::MatElement,
s: &S,
m: &Ro::Element,
s: &[S],
operator: &Op, operator: &Op,
rng: &mut R, rng: &mut R,
) where
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut::MatElement: Zero,
<Mmut as Matrix>::R: RowMut,
) -> Ro
where
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero,
{ {
let s = Mmut::try_convert_from(s.values(), &operator.modulus());
assert!(s.dimension().0 == (lwe_out.dimension().0));
assert!(s.dimension().1 == (lwe_out.dimension().1 - 1));
let s = Ro::try_convert_from(s, operator.modulus());
let mut lwe_out = Ro::zeros(s.as_ref().len() + 1);
// a*s // a*s
RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.get_row_mut(0)[1..]);
let mut sa = Mmut::MatElement::zero();
izip!(lwe_out.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| {
RandomFillUniformInModulus::random_fill(rng, operator.modulus(), &mut lwe_out.as_mut()[1..]);
let mut sa = Ro::Element::zero();
izip!(lwe_out.as_mut().iter().skip(1), s.as_ref()).for_each(|(ai, si)| {
let tmp = operator.mul(ai, si); let tmp = operator.mul(ai, si);
sa = operator.add(&tmp, &sa); sa = operator.add(&tmp, &sa);
}); });
// b = a*s + e + m // b = a*s + e + m
let mut e = Mmut::MatElement::zero();
RandomGaussianDist::random_fill(rng, &operator.modulus(), &mut e);
let b = operator.add(&operator.add(&sa, &e), &m);
lwe_out.set(0, 0, b);
let e = RandomGaussianElementInModulus::random(rng, operator.modulus());
let b = operator.add(&operator.add(&sa, &e), m);
lwe_out.as_mut()[0] = b;
lwe_out
} }
fn decrypt_lwe<M: Matrix, Op: ArithmeticOps<Element = M::MatElement>, S: Secret>(
lwe_ct: &M,
s: &S,
pub(crate) fn decrypt_lwe<
Ro: Row,
Op: ArithmeticOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
S,
>(
lwe_ct: &Ro,
s: &[S],
operator: &Op, operator: &Op,
) -> M::MatElement
) -> Ro::Element
where where
M: TryConvertFrom<[S::Element], Parameters = M::MatElement>,
M::MatElement: Zero,
Ro: TryConvertFrom1<[S], Op::M>,
Ro::Element: Zero,
{ {
let s = M::try_convert_from(s.values(), &operator.modulus());
let s = Ro::try_convert_from(s, operator.modulus());
let mut sa = M::MatElement::zero();
izip!(lwe_ct.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| {
let mut sa = Ro::Element::zero();
izip!(lwe_ct.as_ref().iter().skip(1), s.as_ref()).for_each(|(ai, si)| {
let tmp = operator.mul(ai, si); let tmp = operator.mul(ai, si);
sa = operator.add(&tmp, &sa); sa = operator.add(&tmp, &sa);
}); });
let b = &lwe_ct.get_row_slice(0)[0];
let b = &lwe_ct.as_ref()[0];
operator.sub(b, &sa) operator.sub(b, &sa)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::marker::PhantomData;
use itertools::izip;
use crate::{ use crate::{
backend::ModularOpsU64,
decomposer::{gadget_vector, DefaultDecomposer},
lwe::lwe_key_switch,
random::DefaultSecureRng,
backend::{ModInit, ModulusPowerOf2},
decomposer::DefaultDecomposer,
random::{DefaultSecureRng, NewWithSeed},
utils::{fill_random_ternary_secret_with_hamming_weight, WithLocal},
MatrixEntity, MatrixMut,
}; };
use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret};
use super::*;
const K: usize = 50;
#[derive(Clone)]
struct LweSecret {
pub(crate) values: Vec<i32>,
}
impl LweSecret {
fn random(hw: usize, n: usize) -> LweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
LweSecret { values: out }
})
}
fn values(&self) -> &[i32] {
&self.values
}
}
struct LweKeySwitchingKey<M, R> {
data: M,
_phantom: PhantomData<R>,
}
impl<
M: MatrixMut + MatrixEntity,
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
> From<&(M::R, R::Seed, usize, M::MatElement)> for LweKeySwitchingKey<M, R>
where
M::R: RowMut,
R::Seed: Clone,
M::MatElement: Copy,
{
fn from(value: &(M::R, R::Seed, usize, M::MatElement)) -> Self {
let data_in = &value.0;
let seed = &value.1;
let to_lwe_n = value.2;
let modulus = value.3;
let mut p_rng = R::new_with_seed(seed.clone());
let mut data = M::zeros(data_in.as_ref().len(), to_lwe_n + 1);
izip!(data_in.as_ref().iter(), data.iter_rows_mut()).for_each(|(bi, lwe_i)| {
RandomFillUniformInModulus::random_fill(
&mut p_rng,
&modulus,
&mut lwe_i.as_mut()[1..],
);
lwe_i.as_mut()[0] = *bi;
});
LweKeySwitchingKey {
data,
_phantom: PhantomData,
}
}
}
#[test] #[test]
fn encrypt_decrypt_works() { fn encrypt_decrypt_works() {
let logq = 20;
let logq = 16;
let q = 1u64 << logq; let q = 1u64 << logq;
let lwe_n = 1024; let lwe_n = 1024;
let logp = 3; let logp = 3;
let modq_op = ModularOpsU64::new(q);
let modq_op = ModulusPowerOf2::new(q);
let lwe_sk = LweSecret::random(lwe_n >> 1, lwe_n); let lwe_sk = LweSecret::random(lwe_n >> 1, lwe_n);
let mut rng = DefaultSecureRng::new(); let mut rng = DefaultSecureRng::new();
@ -222,9 +252,9 @@ mod tests {
// encrypt // encrypt
for m in 0..1u64 << logp { for m in 0..1u64 << logp {
let encoded_m = m << (logq - logp); let encoded_m = m << (logq - logp);
let mut lwe_ct = vec![vec![0u64; lwe_n + 1]];
encrypt_lwe(&mut lwe_ct, encoded_m, &lwe_sk, &modq_op, &mut rng);
let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk, &modq_op);
let lwe_ct =
encrypt_lwe::<Vec<u64>, _, _, _>(&encoded_m, &lwe_sk.values(), &modq_op, &mut rng);
let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op);
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64) as u64)
% (1u64 << logp); % (1u64 << logp);
@ -234,52 +264,69 @@ mod tests {
#[test] #[test]
fn key_switch_works() { fn key_switch_works() {
let logq = 16;
let logp = 3;
let logq = 20;
let logp = 2;
let q = 1u64 << logq; let q = 1u64 << logq;
let lwe_in_n = 1024;
let lwe_out_n = 470;
let d_ks = 3;
let lwe_in_n = 2048;
let lwe_out_n = 600;
let d_ks = 5;
let logb = 4; let logb = 4;
let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n);
let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n);
let mut rng = DefaultSecureRng::new(); let mut rng = DefaultSecureRng::new();
let modq_op = ModularOpsU64::new(q);
let modq_op = ModulusPowerOf2::new(q);
// genrate ksk // genrate ksk
for _ in 0..10 {
let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n];
let gadget = gadget_vector(logq, logb, d_ks);
lwe_ksk_keygen(
&lwe_sk_in,
&lwe_sk_out,
&mut ksk,
for _ in 0..1 {
let mut ksk_seed = [0u8; 32];
rng.fill_bytes(&mut ksk_seed);
let mut p_rng = DefaultSecureRng::new_seeded(ksk_seed);
let decomposer = DefaultDecomposer::new(q, logb, d_ks);
let gadget = decomposer.gadget_vector();
let seeded_ksk = seeded_lwe_ksk_keygen(
&lwe_sk_in.values(),
&lwe_sk_out.values(),
&gadget, &gadget,
&modq_op, &modq_op,
&mut p_rng,
&mut rng, &mut rng,
); );
// println!("{:?}", ksk); // println!("{:?}", ksk);
let ksk = LweKeySwitchingKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&(
seeded_ksk, ksk_seed, lwe_out_n, q,
));
for m in 0..(1 << logp) { for m in 0..(1 << logp) {
// encrypt using lwe_sk_in // encrypt using lwe_sk_in
let encoded_m = m << (logq - logp); let encoded_m = m << (logq - logp);
let mut lwe_in_ct = vec![vec![0u64; lwe_in_n + 1]];
encrypt_lwe(&mut lwe_in_ct, encoded_m, &lwe_sk_in, &modq_op, &mut rng);
let lwe_in_ct = encrypt_lwe(&encoded_m, lwe_sk_in.values(), &modq_op, &mut rng);
// key switch from lwe_sk_in to lwe_sk_out // key switch from lwe_sk_in to lwe_sk_out
let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks);
let mut lwe_out_ct = vec![vec![0u64; lwe_out_n + 1]];
lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer);
let mut lwe_out_ct = vec![0u64; lwe_out_n + 1];
let now = std::time::Instant::now();
lwe_key_switch(
&mut lwe_out_ct,
&lwe_in_ct,
&ksk.data,
&modq_op,
&decomposer,
);
println!("Time: {:?}", now.elapsed());
// decrypt lwe_out_ct using lwe_sk_out // decrypt lwe_out_ct using lwe_sk_out
let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out, &modq_op);
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64)
% (1u64 << logp);
assert_eq!(m, m_back, "Expected {m} but got {m_back}");
// dbg!(m, m_back);
// TODO(Jay): Fix me
// let encoded_m_back = decrypt_lwe(&lwe_out_ct,
// &lwe_sk_out.values(), &modq_op); let m_back =
// ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as
// f64).round() as u64)
// % (1u64 << logp);
// let noise =
// measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(),
// &modq_op, &encoded_m); println!("Noise:
// {noise}"); assert_eq!(m, m_back, "Expected
// {m} but got {m_back}"); dbg!(m, m_back);
// dbg!(encoded_m, encoded_m_back); // dbg!(encoded_m, encoded_m_back);
} }
} }

+ 1
- 3
src/main.rs

@ -1,3 +1 @@
fn main() {
println!("Hello, world!");
}
fn main() {}

+ 286
- 0
src/multi_party.rs

@ -0,0 +1,286 @@
use std::fmt::Debug;
use itertools::izip;
use num_traits::Zero;
use crate::{
backend::{GetModulus, Modulus, VectorOps},
ntt::Ntt,
random::{
RandomFillGaussianInModulus, RandomFillUniformInModulus, RandomGaussianElementInModulus,
},
utils::TryConvertFrom1,
ArithmeticOps, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
pub(crate) fn public_key_share<
R: Row + RowMut + RowEntity,
S,
ModOp: VectorOps<Element = R::Element> + GetModulus<Element = R::Element>,
NttOp: Ntt<Element = R::Element>,
Rng: RandomFillGaussianInModulus<[R::Element], ModOp::M>,
PRng: RandomFillUniformInModulus<[R::Element], ModOp::M>,
>(
share_out: &mut R,
s_i: &[S],
modop: &ModOp,
nttop: &NttOp,
p_rng: &mut PRng,
rng: &mut Rng,
) where
R: TryConvertFrom1<[S], ModOp::M>,
{
let ring_size = share_out.as_ref().len();
assert!(s_i.len() == ring_size);
let q = modop.modulus();
// sample a
let mut a = {
let mut a = R::zeros(ring_size);
RandomFillUniformInModulus::random_fill(p_rng, &q, a.as_mut());
a
};
// s*a
nttop.forward(a.as_mut());
let mut s = R::try_convert_from(s_i, &q);
nttop.forward(s.as_mut());
modop.elwise_mul_mut(s.as_mut(), a.as_ref());
nttop.backward(s.as_mut());
RandomFillGaussianInModulus::random_fill(rng, &q, share_out.as_mut());
modop.elwise_add_mut(share_out.as_mut(), s.as_ref()); // s*e + e
}
/// Generate decryption share for LWE ciphertext `lwe_ct` with user's secret `s`
pub(crate) fn multi_party_decryption_share<
R: RowMut + RowEntity,
Mod: Modulus<Element = R::Element>,
ModOp: ArithmeticOps<Element = R::Element> + VectorOps<Element = R::Element> + GetModulus<M = Mod>,
Rng: RandomGaussianElementInModulus<R::Element, Mod>,
S,
>(
lwe_ct: &R,
s: &[S],
mod_op: &ModOp,
rng: &mut Rng,
) -> R::Element
where
R: TryConvertFrom1<[S], Mod>,
R::Element: Zero,
{
assert!(lwe_ct.as_ref().len() == s.len() + 1);
let mut neg_s = R::try_convert_from(s, mod_op.modulus());
mod_op.elwise_neg_mut(neg_s.as_mut());
// share = (\sum -s_i * a_i) + e
let mut share = R::Element::zero();
izip!(neg_s.as_ref().iter(), lwe_ct.as_ref().iter().skip(1)).for_each(|(si, ai)| {
share = mod_op.add(&share, &mod_op.mul(si, ai));
});
let e = rng.random(mod_op.modulus());
share = mod_op.add(&share, &e);
share
}
/// Aggregate decryption shares for `lwe_ct` and return noisy decryption output
/// `m + e`
pub(crate) fn multi_party_aggregate_decryption_shares_and_decrypt<
R: RowMut + RowEntity,
ModOp: ArithmeticOps<Element = R::Element>,
>(
lwe_ct: &R,
shares: &[R::Element],
mod_op: &ModOp,
) -> R::Element
where
R::Element: Zero,
{
let mut sum_shares = R::Element::zero();
shares
.iter()
.for_each(|v| sum_shares = mod_op.add(&sum_shares, v));
mod_op.add(&lwe_ct.as_ref()[0], &sum_shares)
}
pub(crate) fn non_interactive_rgsw_ct<
M: MatrixMut + MatrixEntity,
S,
PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>,
Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>,
NttOp: Ntt<Element = M::MatElement>,
ModOp: VectorOps<Element = M::MatElement> + GetModulus<Element = M::MatElement>,
>(
s: &[S],
u: &[S],
m: &[M::MatElement],
gadget_vec: &[M::MatElement],
p_rng: &mut PRng,
rng: &mut Rng,
nttop: &NttOp,
modop: &ModOp,
) -> (M, M)
where
<M as Matrix>::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity,
M::MatElement: Copy,
{
assert_eq!(s.len(), u.len());
assert_eq!(s.len(), m.len());
let q = modop.modulus();
let d = gadget_vec.len();
let ring_size = s.len();
let mut s_poly_eval = M::R::try_convert_from(s, q);
let mut u_poly_eval = M::R::try_convert_from(u, q);
nttop.forward(s_poly_eval.as_mut());
nttop.forward(u_poly_eval.as_mut());
// encryptions of a_i*u + e + \beta m
let mut enc_beta_m = M::zeros(d, ring_size);
// zero encrypition: a_i*s + e'
let mut zero_encryptions = M::zeros(d, ring_size);
let mut scratch_space = M::R::zeros(ring_size);
izip!(
enc_beta_m.iter_rows_mut(),
zero_encryptions.iter_rows_mut(),
gadget_vec.iter()
)
.for_each(|(e_beta_m, e_zero, beta)| {
// sample a_i
RandomFillUniformInModulus::random_fill(p_rng, q, e_beta_m.as_mut());
e_zero.as_mut().copy_from_slice(e_beta_m.as_ref());
// a_i * u + \beta m + e //
// a_i * u
nttop.forward(e_beta_m.as_mut());
modop.elwise_mul_mut(e_beta_m.as_mut(), u_poly_eval.as_ref());
nttop.backward(e_beta_m.as_mut());
// sample error e
RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut());
// a_i * u + e
modop.elwise_add_mut(e_beta_m.as_mut(), scratch_space.as_ref());
// beta * m
modop.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta);
// a_i * u + e + \beta m
modop.elwise_add_mut(e_beta_m.as_mut(), scratch_space.as_ref());
// a_i * s + e //
// a_i * s
nttop.forward(e_zero.as_mut());
modop.elwise_mul_mut(e_zero.as_mut(), s_poly_eval.as_ref());
nttop.backward(e_zero.as_mut());
// sample error e
RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut());
// a_i * s + e
modop.elwise_add_mut(e_zero.as_mut(), scratch_space.as_ref());
});
(enc_beta_m, zero_encryptions)
}
pub(crate) fn non_interactive_ksk_gen<
M: MatrixMut + MatrixEntity,
S,
PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>,
Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>,
NttOp: Ntt<Element = M::MatElement>,
ModOp: VectorOps<Element = M::MatElement> + GetModulus<Element = M::MatElement>,
>(
s: &[S],
u: &[S],
gadget_vec: &[M::MatElement],
p_rng: &mut PRng,
rng: &mut Rng,
nttop: &NttOp,
modop: &ModOp,
) -> M
where
<M as Matrix>::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity,
M::MatElement: Copy + Debug,
{
assert_eq!(s.len(), u.len());
let q = modop.modulus();
let d = gadget_vec.len();
let ring_size = s.len();
let mut s_poly_eval = M::R::try_convert_from(s, q);
nttop.forward(s_poly_eval.as_mut());
let u_poly = M::R::try_convert_from(u, q);
// a_i * s + \beta u + e
let mut ksk = M::zeros(d, ring_size);
let mut scratch_space = M::R::zeros(ring_size);
izip!(ksk.iter_rows_mut(), gadget_vec.iter()).for_each(|(e_ksk, beta)| {
// sample a_i
RandomFillUniformInModulus::random_fill(p_rng, q, e_ksk.as_mut());
// a_i * s + e + beta u
nttop.forward(e_ksk.as_mut());
modop.elwise_mul_mut(e_ksk.as_mut(), s_poly_eval.as_ref());
nttop.backward(e_ksk.as_mut());
// sample error e
RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut());
// a_i * s + e
modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref());
// \beta * u
modop.elwise_scalar_mul(scratch_space.as_mut(), u_poly.as_ref(), beta);
// a_i * s + e + \beta * u
modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref());
});
ksk
}
pub(crate) fn non_interactive_ksk_zero_encryptions_for_other_party_i<
M: MatrixMut + MatrixEntity,
S,
PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>,
Rng: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>,
NttOp: Ntt<Element = M::MatElement>,
ModOp: VectorOps<Element = M::MatElement> + GetModulus<Element = M::MatElement>,
>(
s: &[S],
gadget_vec: &[M::MatElement],
p_rng: &mut PRng,
rng: &mut Rng,
nttop: &NttOp,
modop: &ModOp,
) -> M
where
<M as Matrix>::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity,
M::MatElement: Copy + Debug,
{
let q = modop.modulus();
let d = gadget_vec.len();
let ring_size = s.len();
let mut s_poly_eval = M::R::try_convert_from(s, q);
nttop.forward(s_poly_eval.as_mut());
// a_i * s + e
let mut zero_encs = M::zeros(d, ring_size);
let mut scratch_space = M::R::zeros(ring_size);
izip!(zero_encs.iter_rows_mut()).for_each(|e_zero| {
// sample a_i
RandomFillUniformInModulus::random_fill(p_rng, q, e_zero.as_mut());
// a_i * s + e
nttop.forward(e_zero.as_mut());
modop.elwise_mul_mut(e_zero.as_mut(), s_poly_eval.as_ref());
nttop.backward(e_zero.as_mut());
// sample error e
RandomFillGaussianInModulus::random_fill(rng, q, scratch_space.as_mut());
modop.elwise_add_mut(e_zero.as_mut(), scratch_space.as_ref());
});
zero_encs
}

+ 229
- 96
src/ntt.rs

@ -1,11 +1,18 @@
use itertools::Itertools;
use rand::{thread_rng, Rng, RngCore};
use itertools::{izip, Itertools};
use rand::{Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::{ use crate::{
backend::{ArithmeticOps, ModularOpsU64},
utils::{mod_exponent, mod_inverse, shoup_representation_fq},
backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
utils::{mod_exponent, mod_inverse, ShoupMul},
}; };
pub trait NttInit<M> {
/// Ntt istance must be compatible across different instances with same `q`
/// and `n`
fn new(q: &M, n: usize) -> Self;
}
pub trait Ntt { pub trait Ntt {
type Element; type Element;
fn forward_lazy(&self, v: &mut [Self::Element]); fn forward_lazy(&self, v: &mut [Self::Element]);
@ -21,27 +28,50 @@ pub trait Ntt {
/// and both x' and y' are \in [0, 4q) /// and both x' and y' are \in [0, 4q)
/// ///
/// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) /// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
pub unsafe fn forward_butterly(
x: *mut u64,
y: *mut u64,
w: &u64,
w_shoup: &u64,
q: &u64,
q_twice: &u64,
) {
debug_assert!(*x < *q * 4, "{} >= (4q){}", *x, 4 * q);
debug_assert!(*y < *q * 4, "{} >= (4q){}", *y, 4 * q);
pub fn forward_butterly_0_to_4q(
mut x: u64,
y: u64,
w: u64,
w_shoup: u64,
q: u64,
q_twice: u64,
) -> (u64, u64) {
debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
if *x >= *q_twice {
*x = *x - q_twice;
if x >= q_twice {
x = x - q_twice;
} }
// TODO (Jay): Hot path expected. How expensive is it?
let k = ((*w_shoup as u128 * *y as u128) >> 64) as u64;
let t = w.wrapping_mul(*y).wrapping_sub(k.wrapping_mul(*q));
let t = ShoupMul::mul(y, w, w_shoup, q);
*y = *x + q_twice - t;
*x = *x + t;
(x + t, x + q_twice - t)
}
pub fn forward_butterly_0_to_2q(
mut x: u64,
y: u64,
w: u64,
w_shoup: u64,
q: u64,
q_twice: u64,
) -> (u64, u64) {
debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);
if x >= q_twice {
x = x - q_twice;
}
let t = ShoupMul::mul(y, w, w_shoup, q);
let ox = x.wrapping_add(t);
let oy = x.wrapping_sub(t);
(
(ox).min(ox.wrapping_sub(q_twice)),
oy.min(oy.wrapping_add(q_twice)),
)
} }
/// Inverse butterfly routine of Inverse Number theoretic transform. Given /// Inverse butterfly routine of Inverse Number theoretic transform. Given
@ -51,27 +81,26 @@ pub unsafe fn forward_butterly(
/// and both x' and y' are \in [0, 2q) /// and both x' and y' are \in [0, 2q)
/// ///
/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) /// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
pub unsafe fn inverse_butterfly(
x: *mut u64,
y: *mut u64,
w_inv: &u64,
w_inv_shoup: &u64,
q: &u64,
q_twice: &u64,
) {
debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x);
debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y);
pub fn inverse_butterfly_0_to_2q(
x: u64,
y: u64,
w_inv: u64,
w_inv_shoup: u64,
q: u64,
q_twice: u64,
) -> (u64, u64) {
debug_assert!(x < q_twice, "{} >= (2q){q_twice}", x);
debug_assert!(y < q_twice, "{} >= (2q){q_twice}", y);
let mut x_dash = *x + *y;
if x_dash >= *q_twice {
let mut x_dash = x + y;
if x_dash >= q_twice {
x_dash -= q_twice x_dash -= q_twice
} }
let t = *x + q_twice - *y;
let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path
*y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q));
let t = x + q_twice - y;
let y = ShoupMul::mul(t, w_inv, w_inv_shoup, q);
*x = x_dash;
(x_dash, y)
} }
/// Number theoretic transform of vector `a` where each element can be in range /// Number theoretic transform of vector `a` where each element can be in range
@ -79,7 +108,7 @@ pub unsafe fn inverse_butterfly(
/// ///
/// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf. /// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf.
pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
debug_assert!(a.len() == psi.len());
assert!(a.len() == psi.len());
let n = a.len(); let n = a.len();
let mut t = n; let mut t = n;
@ -87,30 +116,67 @@ pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice:
let mut m = 1; let mut m = 1;
while m < n { while m < n {
t >>= 1; t >>= 1;
for i in 0..m {
let j_1 = 2 * i * t;
let j_2 = j_1 + t;
unsafe {
let w = psi.get_unchecked(m + i);
let w_shoup = psi_shoup.get_unchecked(m + i);
for j in j_1..j_2 {
let x = a.get_unchecked_mut(j) as *mut u64;
let y = a.get_unchecked_mut(j + t) as *mut u64;
forward_butterly(x, y, w, w_shoup, &q, &q_twice);
let w = &psi[m..];
let w_shoup = &psi_shoup[m..];
if t == 1 {
for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
a[0] = ox;
a[1] = oy;
}
} else {
for i in 0..m {
let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
*x = ox;
*y = oy;
} }
} }
} }
m <<= 1; m <<= 1;
} }
}
/// Same as `ntt_lazy` with output in range [0, q)
pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
assert!(a.len() == psi.len());
a.iter_mut().for_each(|a0| {
if *a0 >= q_twice {
*a0 -= q_twice
let n = a.len();
let mut t = n;
let mut m = 1;
while m < n {
t >>= 1;
let w = &psi[m..];
let w_shoup = &psi_shoup[m..];
if t == 1 {
for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
// reduce from range [0, 2q) to [0, q)
a[0] = ox.min(ox.wrapping_sub(q));
a[1] = oy.min(oy.wrapping_sub(q));
}
} else {
for i in 0..m {
let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
*x = ox;
*y = oy;
}
}
} }
});
m <<= 1;
}
} }
/// Inverse number theoretic transform of input vector `a` with each element can /// Inverse number theoretic transform of input vector `a` with each element can
@ -123,36 +189,92 @@ pub fn ntt_inv_lazy(
psi_inv: &[u64], psi_inv: &[u64],
psi_inv_shoup: &[u64], psi_inv_shoup: &[u64],
n_inv: u64, n_inv: u64,
n_inv_shoup: u64,
q: u64, q: u64,
q_twice: u64, q_twice: u64,
) { ) {
debug_assert!(a.len() == psi_inv.len());
assert!(a.len() == psi_inv.len());
let mut m = a.len();
let mut m = a.len() >> 1;
let mut t = 1; let mut t = 1;
while m > 1 {
let mut j_1: usize = 0;
let h = m >> 1;
for i in 0..h {
let j_2 = j_1 + t;
unsafe {
let w_inv = psi_inv.get_unchecked(h + i);
let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i);
for j in j_1..j_2 {
let x = a.get_unchecked_mut(j) as *mut u64;
let y = a.get_unchecked_mut(j + t) as *mut u64;
inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice);
while m > 0 {
if m == 1 {
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) =
inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
*x = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
*y = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
}
} else {
let w_inv = &psi_inv[m..];
let w_inv_shoup = &psi_inv_shoup[m..];
for i in 0..m {
let a = &mut a[2 * i * t..2 * (i + 1) * t];
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) =
inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
*x = ox;
*y = oy;
} }
} }
j_1 = j_1 + 2 * t;
} }
t *= 2; t *= 2;
m >>= 1; m >>= 1;
} }
}
/// Same as `ntt_inv_lazy` with output in range [0, q)
pub fn ntt_inv(
a: &mut [u64],
psi_inv: &[u64],
psi_inv_shoup: &[u64],
n_inv: u64,
n_inv_shoup: u64,
q: u64,
q_twice: u64,
) {
assert!(a.len() == psi_inv.len());
let mut m = a.len() >> 1;
let mut t = 1;
a.iter_mut()
.for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64);
while m > 0 {
if m == 1 {
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) =
inverse_butterfly_0_to_2q(*x, *y, psi_inv[1], psi_inv_shoup[1], q, q_twice);
let ox = ShoupMul::mul(ox, n_inv, n_inv_shoup, q);
let oy = ShoupMul::mul(oy, n_inv, n_inv_shoup, q);
*x = ox.min(ox.wrapping_sub(q));
*y = oy.min(oy.wrapping_sub(q));
}
} else {
let w_inv = &psi_inv[m..];
let w_inv_shoup = &psi_inv_shoup[m..];
for i in 0..m {
let a = &mut a[2 * i * t..2 * (i + 1) * t];
let (left, right) = a.split_at_mut(t);
for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
let (ox, oy) =
inverse_butterfly_0_to_2q(*x, *y, w_inv[i], w_inv_shoup[i], q, q_twice);
*x = ox;
*y = oy;
}
}
}
t *= 2;
m >>= 1;
}
} }
/// Find n^{th} root of unity in field F_q, if one exists /// Find n^{th} root of unity in field F_q, if one exists
@ -184,11 +306,13 @@ pub(crate) fn find_primitive_root(q: u64, n: u64, rng: &mut R) -> Op
None None
} }
#[derive(Debug)]
pub struct NttBackendU64 { pub struct NttBackendU64 {
q: u64, q: u64,
q_twice: u64, q_twice: u64,
n: u64,
_n: u64,
n_inv: u64, n_inv: u64,
n_inv_shoup: u64,
psi_powers_bo: Box<[u64]>, psi_powers_bo: Box<[u64]>,
psi_inv_powers_bo: Box<[u64]>, psi_inv_powers_bo: Box<[u64]>,
psi_powers_bo_shoup: Box<[u64]>, psi_powers_bo_shoup: Box<[u64]>,
@ -196,12 +320,11 @@ pub struct NttBackendU64 {
} }
impl NttBackendU64 { impl NttBackendU64 {
pub fn new(q: u64, n: usize) -> Self {
fn _new(q: u64, n: usize) -> Self {
// \psi = 2n^{th} primitive root of unity in F_q // \psi = 2n^{th} primitive root of unity in F_q
let mut rng = thread_rng();
let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
.expect("Unable to find 2n^th root of unity"); .expect("Unable to find 2n^th root of unity");
let psi_inv = mod_inverse(psi, q); let psi_inv = mod_inverse(psi, q);
// assert!( // assert!(
@ -238,11 +361,11 @@ impl NttBackendU64 {
// shoup representation // shoup representation
let psi_powers_bo_shoup = psi_powers_bo let psi_powers_bo_shoup = psi_powers_bo
.iter() .iter()
.map(|v| shoup_representation_fq(*v, q))
.map(|v| ShoupMul::representation(*v, q))
.collect_vec(); .collect_vec();
let psi_inv_powers_bo_shoup = psi_inv_powers_bo let psi_inv_powers_bo_shoup = psi_inv_powers_bo
.iter() .iter()
.map(|v| shoup_representation_fq(*v, q))
.map(|v| ShoupMul::representation(*v, q))
.collect_vec(); .collect_vec();
// n^{-1} \mod{q} // n^{-1} \mod{q}
@ -251,8 +374,9 @@ impl NttBackendU64 {
NttBackendU64 { NttBackendU64 {
q, q,
q_twice: 2 * q, q_twice: 2 * q,
n: n as u64,
_n: n as u64,
n_inv, n_inv,
n_inv_shoup: ShoupMul::representation(n_inv, q),
psi_powers_bo: psi_powers_bo.into_boxed_slice(), psi_powers_bo: psi_powers_bo.into_boxed_slice(),
psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
@ -261,14 +385,11 @@ impl NttBackendU64 {
} }
} }
impl NttBackendU64 {
fn reduce_from_lazy(&self, a: &mut [u64]) {
let q = self.q;
a.iter_mut().for_each(|a0| {
if *a0 >= q {
*a0 = *a0 - q;
}
});
impl<M: Modulus<Element = u64>> NttInit<M> for NttBackendU64 {
fn new(q: &M, n: usize) -> Self {
// This NTT does not support native modulus
assert!(!q.is_native());
NttBackendU64::_new(q.q().unwrap(), n)
} }
} }
@ -286,14 +407,13 @@ impl Ntt for NttBackendU64 {
} }
fn forward(&self, v: &mut [Self::Element]) { fn forward(&self, v: &mut [Self::Element]) {
ntt_lazy(
ntt(
v, v,
&self.psi_powers_bo, &self.psi_powers_bo,
&self.psi_powers_bo_shoup, &self.psi_powers_bo_shoup,
self.q, self.q,
self.q_twice, self.q_twice,
); );
self.reduce_from_lazy(v);
} }
fn backward_lazy(&self, v: &mut [Self::Element]) { fn backward_lazy(&self, v: &mut [Self::Element]) {
@ -302,24 +422,26 @@ impl Ntt for NttBackendU64 {
&self.psi_inv_powers_bo, &self.psi_inv_powers_bo,
&self.psi_inv_powers_bo_shoup, &self.psi_inv_powers_bo_shoup,
self.n_inv, self.n_inv,
self.n_inv_shoup,
self.q, self.q,
self.q_twice, self.q_twice,
) )
} }
fn backward(&self, v: &mut [Self::Element]) { fn backward(&self, v: &mut [Self::Element]) {
ntt_inv_lazy(
ntt_inv(
v, v,
&self.psi_inv_powers_bo, &self.psi_inv_powers_bo,
&self.psi_inv_powers_bo_shoup, &self.psi_inv_powers_bo_shoup,
self.n_inv, self.n_inv,
self.n_inv_shoup,
self.q, self.q,
self.q_twice, self.q_twice,
); );
self.reduce_from_lazy(v);
} }
} }
#[cfg(test)]
mod tests { mod tests {
use itertools::Itertools; use itertools::Itertools;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
@ -327,7 +449,7 @@ mod tests {
use super::NttBackendU64; use super::NttBackendU64;
use crate::{ use crate::{
backend::{ModularOpsU64, VectorOps},
backend::{ModInit, ModularOpsU64, VectorOps},
ntt::Ntt, ntt::Ntt,
utils::{generate_prime, negacyclic_mul}, utils::{generate_prime, negacyclic_mul},
}; };
@ -344,29 +466,40 @@ mod tests {
.collect_vec() .collect_vec()
} }
fn assert_output_range(a: &[u64], max_val: u64) {
a.iter()
.for_each(|v| assert!(v <= &max_val, "{v} > {max_val}"));
}
#[test] #[test]
fn native_ntt_backend_works() { fn native_ntt_backend_works() {
// TODO(Jay): Improve tests. Add tests for different primes and ring size. // TODO(Jay): Improve tests. Add tests for different primes and ring size.
let ntt_backend = NttBackendU64::new(Q_60_BITS, N);
let ntt_backend = NttBackendU64::_new(Q_60_BITS, N);
for _ in 0..K { for _ in 0..K {
let mut a = random_vec_in_fq(N, Q_60_BITS); let mut a = random_vec_in_fq(N, Q_60_BITS);
let a_clone = a.clone(); let a_clone = a.clone();
ntt_backend.forward(&mut a); ntt_backend.forward(&mut a);
assert_output_range(a.as_ref(), Q_60_BITS - 1);
assert_ne!(a, a_clone); assert_ne!(a, a_clone);
ntt_backend.backward(&mut a); ntt_backend.backward(&mut a);
assert_output_range(a.as_ref(), Q_60_BITS - 1);
assert_eq!(a, a_clone); assert_eq!(a, a_clone);
ntt_backend.forward_lazy(&mut a); ntt_backend.forward_lazy(&mut a);
assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
assert_ne!(a, a_clone); assert_ne!(a, a_clone);
ntt_backend.backward(&mut a); ntt_backend.backward(&mut a);
assert_output_range(a.as_ref(), Q_60_BITS - 1);
assert_eq!(a, a_clone); assert_eq!(a, a_clone);
ntt_backend.forward(&mut a); ntt_backend.forward(&mut a);
assert_output_range(a.as_ref(), Q_60_BITS - 1);
ntt_backend.backward_lazy(&mut a); ntt_backend.backward_lazy(&mut a);
assert_output_range(a.as_ref(), (2 * Q_60_BITS) - 1);
// reduce // reduce
a.iter_mut().for_each(|a0| { a.iter_mut().for_each(|a0| {
if *a0 > Q_60_BITS {
if *a0 >= Q_60_BITS {
*a0 -= *a0 - Q_60_BITS; *a0 -= *a0 - Q_60_BITS;
} }
}); });
@ -376,13 +509,13 @@ mod tests {
#[test] #[test]
fn native_ntt_negacylic_mul() { fn native_ntt_negacylic_mul() {
let primes = [40, 50, 60]
let primes = [25, 40, 50, 60]
.iter() .iter()
.map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap()) .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
.collect_vec(); .collect_vec();
for p in primes.into_iter() { for p in primes.into_iter() {
let ntt_backend = NttBackendU64::new(p, N);
let ntt_backend = NttBackendU64::_new(p, N);
let modulus_backend = ModularOpsU64::new(p); let modulus_backend = ModularOpsU64::new(p);
for _ in 0..K { for _ in 0..K {
let a = random_vec_in_fq(N, p); let a = random_vec_in_fq(N, p);

+ 0
- 3
src/num.rs

@ -1,3 +0,0 @@
use num_traits::{Num, PrimInt, WrappingShl, WrappingShr, Zero};
pub trait UnsignedInteger: Zero + Num {}

+ 482
- 0
src/pbs.rs

@ -0,0 +1,482 @@
use std::fmt::Display;
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero};
use crate::{
backend::{ArithmeticOps, Modulus, ShoupMatrixFMA, VectorOps},
decomposer::{Decomposer, RlweDecomposer},
lwe::lwe_key_switch,
ntt::Ntt,
rgsw::{
rlwe_auto_shoup, rlwe_by_rgsw_shoup, RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef,
RuntimeScratchMutRef,
},
Matrix, MatrixEntity, MatrixMut, RowMut,
};
pub(crate) trait PbsKey {
type RgswCt;
type AutoKey;
type LweKskKey;
/// RGSW ciphertext of LWE secret elements
fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::RgswCt;
/// Key for automorphism with g^k. For -g use k = 0
fn galois_key_for_auto(&self, k: usize) -> &Self::AutoKey;
/// LWE ksk to key switch from RLWE secret to LWE secret
fn lwe_ksk(&self) -> &Self::LweKskKey;
}
pub(crate) trait WithShoupRepr: AsRef<Self::M> {
type M;
fn shoup_repr(&self) -> &Self::M;
}
pub(crate) trait PbsInfo {
/// Type of Matrix
type M: Matrix;
/// Type of Ciphertext modulus
type Modulus: Modulus<Element = <Self::M as Matrix>::MatElement>;
/// Type of Ntt Operator for Ring polynomials
type NttOp: Ntt<Element = <Self::M as Matrix>::MatElement>;
/// Type of Signed Decomposer
type D: Decomposer<Element = <Self::M as Matrix>::MatElement>;
// Although both `RlweModOp` and `LweModOp` types have same bounds, they can be
// different types. For ex, type RlweModOp may only support native modulus,
// where LweModOp may only support prime modulus, etc.
/// Type of RLWE Modulus Operator
type RlweModOp: ArithmeticOps<Element = <Self::M as Matrix>::MatElement>
+ ShoupMatrixFMA<<Self::M as Matrix>::R>;
/// Type of LWE Modulus Operator
type LweModOp: VectorOps<Element = <Self::M as Matrix>::MatElement>
+ ArithmeticOps<Element = <Self::M as Matrix>::MatElement>;
/// RLWE ciphertext modulus
fn rlwe_q(&self) -> &Self::Modulus;
/// LWE ciphertext modulus
fn lwe_q(&self) -> &Self::Modulus;
/// Blind rotation modulus. It is the modulus to which we switch for blind
/// rotation. Since blind rotation decrypts LWE ciphetext in the exponent of
/// ring polynmial (which is a ring mod 2N), `br_q <= 2N`
fn br_q(&self) -> usize;
/// Ring polynomial size `N`
fn rlwe_n(&self) -> usize;
/// LWE dimension `n`
fn lwe_n(&self) -> usize;
/// Embedding fator for ring X^{q}+1 inside
fn embedding_factor(&self) -> usize;
/// Window size parameter LKMC++ blind rotaiton
fn w(&self) -> usize;
/// generator `g` for group Z^*_{br_q}
fn g(&self) -> isize;
/// LWE key switching decomposer
fn lwe_decomposer(&self) -> &Self::D;
/// RLWE x RGSW decoposer
fn rlwe_rgsw_decomposer(&self) -> &(Self::D, Self::D);
/// RLWE auto decomposer
fn auto_decomposer(&self) -> &Self::D;
/// LWE modulus operator
fn modop_lweq(&self) -> &Self::LweModOp;
/// RLWE modulus operator
fn modop_rlweq(&self) -> &Self::RlweModOp;
/// Ntt operators
fn nttop_rlweq(&self) -> &Self::NttOp;
/// Maps a \in Z^*_{br_q} to discrete log k, with generator g (i.e. g^k =
/// a). Returned vector is of size q that stores dlog of `a` at `vec[a]`.
///
/// For any `a`, if k is s.t. `a = g^{k} % br_q`, then `k` is expressed as
/// k. If `k` is s.t `a = -g^{k} % br_q`, then `k` is expressed as
/// k=k+q/4
fn g_k_dlog_map(&self) -> &[usize];
/// Returns auto map and index vector for auto element g^k. For auto element
/// -g set k = 0.
fn rlwe_auto_map(&self, k: usize) -> &(Vec<usize>, Vec<bool>);
}
/// - Mod down
/// - key switching
/// - mod down
/// - blind rotate
pub(crate) fn pbs<
M: MatrixMut + MatrixEntity,
MShoup: WithShoupRepr<M = M>,
P: PbsInfo<M = M>,
K: PbsKey<RgswCt = MShoup, AutoKey = MShoup, LweKskKey = M>,
>(
pbs_info: &P,
test_vec: &M::R,
lwe_in: &mut M::R,
pbs_key: &K,
scratch_lwe_vec: &mut M::R,
scratch_blind_rotate_matrix: &mut M,
) where
<M as Matrix>::R: RowMut,
M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display,
{
let rlwe_q = pbs_info.rlwe_q();
let lwe_q = pbs_info.lwe_q();
let br_q = pbs_info.br_q();
let rlwe_qf64 = rlwe_q.q_as_f64().unwrap();
let lwe_qf64 = lwe_q.q_as_f64().unwrap();
let br_qf64 = br_q.to_f64().unwrap();
let rlwe_n = pbs_info.rlwe_n();
// moddown Q -> Q_ks
lwe_in.as_mut().iter_mut().for_each(|v| {
*v =
M::MatElement::from_f64(((v.to_f64().unwrap() * lwe_qf64) / rlwe_qf64).round()).unwrap()
});
// key switch RLWE secret to LWE secret
// let now = std::time::Instant::now();
scratch_lwe_vec.as_mut().fill(M::MatElement::zero());
lwe_key_switch(
scratch_lwe_vec,
lwe_in,
pbs_key.lwe_ksk(),
pbs_info.modop_lweq(),
pbs_info.lwe_decomposer(),
);
// println!("Time: {:?}", now.elapsed());
// odd moddown Q_ks -> q
let g_k_dlog_map = pbs_info.g_k_dlog_map();
let mut g_k_si = vec![vec![]; br_q >> 1];
scratch_lwe_vec
.as_ref()
.iter()
.skip(1)
.enumerate()
.for_each(|(index, v)| {
let odd_v = mod_switch_odd(v.to_f64().unwrap(), lwe_qf64, br_qf64);
// dlog `k` for `odd_v` is stored as `k` if odd_v = +g^{k}. If odd_v = -g^{k},
// then `k` is stored as `q/4 + k`.
let k = g_k_dlog_map[odd_v];
// assert!(k != 0);
g_k_si[k].push(index);
});
// handle b and set trivial test RLWE
let g = pbs_info.g() as usize;
let g_times_b = (g * mod_switch_odd(
scratch_lwe_vec.as_ref()[0].to_f64().unwrap(),
lwe_qf64,
br_qf64,
)) % (br_q);
// v = (v(X) * X^{g*b}) mod X^{q/2}+1
let br_qby2 = br_q >> 1;
let mut gb_monomial_sign = true;
let mut gb_monomial_exp = g_times_b;
// X^{g*b} mod X^{q/2}+1
if gb_monomial_exp > br_qby2 {
gb_monomial_exp -= br_qby2;
gb_monomial_sign = false
}
// monomial mul
let mut trivial_rlwe_test_poly = M::zeros(2, rlwe_n);
if pbs_info.embedding_factor() == 1 {
monomial_mul(
test_vec.as_ref(),
trivial_rlwe_test_poly.get_row_mut(1).as_mut(),
gb_monomial_exp,
gb_monomial_sign,
br_qby2,
pbs_info.modop_rlweq(),
);
} else {
// use lwe_in to store the `t = v(X) * X^{g*2} mod X^{q/2}+1` temporarily. This
// works because q/2 <= N (where N is lwe_in LWE dimension) always.
monomial_mul(
test_vec.as_ref(),
&mut lwe_in.as_mut()[..br_qby2],
gb_monomial_exp,
gb_monomial_sign,
br_qby2,
pbs_info.modop_rlweq(),
);
// emebed poly `t` in ring X^{q/2}+1 inside the bigger ring X^{N}+1
let embed_factor = pbs_info.embedding_factor();
let partb_trivial_rlwe = trivial_rlwe_test_poly.get_row_mut(1);
lwe_in.as_ref()[..br_qby2]
.iter()
.enumerate()
.for_each(|(index, v)| {
partb_trivial_rlwe[embed_factor * index] = *v;
});
}
// let now = std::time::Instant::now();
// blind rotate
blind_rotation(
&mut trivial_rlwe_test_poly,
scratch_blind_rotate_matrix,
pbs_info.g(),
pbs_info.w(),
br_q,
&g_k_si,
pbs_info.rlwe_rgsw_decomposer(),
pbs_info.auto_decomposer(),
pbs_info.nttop_rlweq(),
pbs_info.modop_rlweq(),
pbs_info,
pbs_key,
);
// println!("Blind rotation time: {:?}", now.elapsed());
// sample extract
sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0);
}
/// LMKCY+ Blind rotation
///
/// - gk_to_si: Contains LWE secret index `i` in array of secret indices at k^th
/// index if a_i = g^k if k < q/4 or a_i = -g^k if k > q/4. [g^0, ...,
/// g^{q/2-1}, -g^0, -g^1, .., -g^{q/2-1}]
fn blind_rotation<
Mmut: MatrixMut,
RlweD: RlweDecomposer<Element = Mmut::MatElement>,
AutoD: Decomposer<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
ModOp: ArithmeticOps<Element = Mmut::MatElement> + ShoupMatrixFMA<Mmut::R>,
MShoup: WithShoupRepr<M = Mmut>,
K: PbsKey<RgswCt = MShoup, AutoKey = MShoup>,
P: PbsInfo<M = Mmut>,
>(
trivial_rlwe_test_poly: &mut Mmut,
scratch_matrix: &mut Mmut,
_g: isize,
w: usize,
q: usize,
gk_to_si: &[Vec<usize>],
rlwe_rgsw_decomposer: &RlweD,
auto_decomposer: &AutoD,
ntt_op: &NttOp,
mod_op: &ModOp,
parameters: &P,
pbs_key: &K,
) where
<Mmut as Matrix>::R: RowMut,
Mmut::MatElement: Copy + Zero,
{
let mut is_trivial = true;
let mut scratch_matrix = RuntimeScratchMutRef::new(scratch_matrix.as_mut());
let mut rlwe = RlweCiphertextMutRef::new(trivial_rlwe_test_poly.as_mut());
let d_a = rlwe_rgsw_decomposer.a().decomposition_count().0;
let d_b = rlwe_rgsw_decomposer.b().decomposition_count().0;
let d_auto = auto_decomposer.decomposition_count().0;
let q_by_4 = q >> 2;
// let mut count = 0;
// -(g^k)
let mut v = 0;
for i in (1..q_by_4).rev() {
// dbg!(q_by_4 + i);
let s_indices = &gk_to_si[q_by_4 + i];
s_indices.iter().for_each(|s_index| {
// let new = std::time::Instant::now();
let ct = pbs_key.rgsw_ct_lwe_si(*s_index);
rlwe_by_rgsw_shoup(
&mut rlwe,
&RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b),
&RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b),
&mut scratch_matrix,
rlwe_rgsw_decomposer,
ntt_op,
mod_op,
is_trivial,
);
is_trivial = false;
// println!("Rlwe x Rgsw time: {:?}", new.elapsed());
});
v += 1;
if gk_to_si[q_by_4 + i - 1].len() != 0 || v == w || i == 1 {
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v);
// let now = std::time::Instant::now();
let auto_key = pbs_key.galois_key_for_auto(v);
rlwe_auto_shoup(
&mut rlwe,
&RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto),
&RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto),
&mut scratch_matrix,
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
auto_decomposer,
is_trivial,
);
// println!("Auto time: {:?}", now.elapsed());
// count += 1;
v = 0;
}
}
// -(g^0)
{
gk_to_si[q_by_4].iter().for_each(|s_index| {
let ct = pbs_key.rgsw_ct_lwe_si(*s_index);
rlwe_by_rgsw_shoup(
&mut rlwe,
&RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b),
&RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b),
&mut scratch_matrix,
rlwe_rgsw_decomposer,
ntt_op,
mod_op,
is_trivial,
);
is_trivial = false;
});
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(0);
let auto_key = pbs_key.galois_key_for_auto(0);
rlwe_auto_shoup(
&mut rlwe,
&RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto),
&RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto),
&mut scratch_matrix,
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
auto_decomposer,
is_trivial,
);
// count += 1;
}
// +(g^k)
let mut v = 0;
for i in (1..q_by_4).rev() {
let s_indices = &gk_to_si[i];
s_indices.iter().for_each(|s_index| {
let ct = pbs_key.rgsw_ct_lwe_si(*s_index);
rlwe_by_rgsw_shoup(
&mut rlwe,
&RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b),
&RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b),
&mut scratch_matrix,
rlwe_rgsw_decomposer,
ntt_op,
mod_op,
is_trivial,
);
is_trivial = false;
});
v += 1;
if gk_to_si[i - 1].len() != 0 || v == w || i == 1 {
let (auto_map_index, auto_map_sign) = parameters.rlwe_auto_map(v);
let auto_key = pbs_key.galois_key_for_auto(v);
rlwe_auto_shoup(
&mut rlwe,
&RlweKskRef::new(auto_key.as_ref().as_ref(), d_auto),
&RlweKskRef::new(auto_key.shoup_repr().as_ref(), d_auto),
&mut scratch_matrix,
&auto_map_index,
&auto_map_sign,
mod_op,
ntt_op,
auto_decomposer,
is_trivial,
);
v = 0;
// count += 1;
}
}
// +(g^0)
gk_to_si[0].iter().for_each(|s_index| {
let ct = pbs_key.rgsw_ct_lwe_si(*s_index);
rlwe_by_rgsw_shoup(
&mut rlwe,
&RgswCiphertextRef::new(ct.as_ref().as_ref(), d_a, d_b),
&RgswCiphertextRef::new(ct.shoup_repr().as_ref(), d_a, d_b),
&mut scratch_matrix,
rlwe_rgsw_decomposer,
ntt_op,
mod_op,
is_trivial,
);
is_trivial = false;
});
// println!("Auto count: {count}");
}
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap();
//TODO(Jay): check correctness of this
odd_v + ((odd_v & 1) ^ 1)
}
// TODO(Jay): Add tests for sample extract
pub(crate) fn sample_extract<M: Matrix + MatrixMut, ModOp: ArithmeticOps<Element = M::MatElement>>(
lwe_out: &mut M::R,
rlwe_in: &M,
mod_op: &ModOp,
index: usize,
) where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
{
let ring_size = rlwe_in.dimension().1;
assert!(ring_size + 1 == lwe_out.as_ref().len());
// index..=0
let to = &mut lwe_out.as_mut()[1..];
let from = rlwe_in.get_row_slice(0);
for i in 0..index + 1 {
to[i] = from[index - i];
}
// -(N..index)
for i in index + 1..ring_size {
to[i] = mod_op.neg(&from[ring_size + index - i]);
}
// set b
lwe_out.as_mut()[0] = *rlwe_in.get(1, index);
}
/// Monomial multiplication (p(X)*X^{mon_exp})
///
/// - p_out: Output is written to p_out and independent of values in p_out
fn monomial_mul<El, ModOp: ArithmeticOps<Element = El>>(
p_in: &[El],
p_out: &mut [El],
mon_exp: usize,
mon_sign: bool,
ring_size: usize,
mod_op: &ModOp,
) where
El: Copy,
{
debug_assert!(p_in.as_ref().len() == ring_size);
debug_assert!(p_in.as_ref().len() == p_out.as_ref().len());
debug_assert!(mon_exp < ring_size);
p_in.as_ref().iter().enumerate().for_each(|(index, v)| {
let mut to_index = index + mon_exp;
let mut to_sign = mon_sign;
if to_index >= ring_size {
to_index = to_index - ring_size;
to_sign = !to_sign;
}
if !to_sign {
p_out.as_mut()[to_index] = mod_op.neg(v);
} else {
p_out.as_mut()[to_index] = *v;
}
});
}

+ 126
- 89
src/random.rs

@ -1,33 +1,69 @@
use std::cell::RefCell; use std::cell::RefCell;
use itertools::izip; use itertools::izip;
use rand::{distributions::Uniform, thread_rng, CryptoRng, Rng, RngCore, SeedableRng};
use num_traits::{FromPrimitive, PrimInt, Zero};
use rand::{distributions::Uniform, Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng; use rand_chacha::ChaCha8Rng;
use rand_distr::Distribution;
use rand_distr::{uniform::SampleUniform, Distribution};
use crate::utils::WithLocal;
use crate::{backend::Modulus, utils::WithLocal};
thread_local! { thread_local! {
pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new()); pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new());
} }
pub trait RandomGaussianDist<M>
pub trait NewWithSeed {
type Seed;
fn new_with_seed(seed: Self::Seed) -> Self;
}
pub trait RandomElementInModulus<T, M> {
/// Sample Random element of type T in range [0, modulus)
fn random(&mut self, modulus: &M) -> T;
}
pub trait RandomGaussianElementInModulus<T, M> {
/// Sample Random gaussian element from \mu = 0.0 and \sigma = 3.19. Sampled
/// element is converted to signed representation in modulus.
fn random(&mut self, modulus: &M) -> T;
}
pub trait RandomFill<M>
where
M: ?Sized,
{
/// Fill container with random elements of type of its elements
fn random_fill(&mut self, container: &mut M);
}
pub trait RandomFillGaussian<M>
where
M: ?Sized,
{
/// Fill container with random elements sampled from normal distribtuion
/// with \mu = 0.0 and \sigma = 3.19.
fn random_fill(&mut self, container: &mut M);
}
pub trait RandomFillUniformInModulus<M, P>
where where
M: ?Sized, M: ?Sized,
{ {
type Parameters: ?Sized;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M);
/// Fill container with random elements in range [0, modulus)
fn random_fill(&mut self, modulus: &P, container: &mut M);
} }
pub trait RandomUniformDist<M>
pub trait RandomFillGaussianInModulus<M, P>
where where
M: ?Sized, M: ?Sized,
{ {
type Parameters: ?Sized;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M);
/// Fill container with gaussian elements sampled from normal distribution
/// with \mu = 0.0 and \sigma = 3.19. Elements are converted to signed
/// represented in the modulus.
fn random_fill(&mut self, modulus: &P, container: &mut M);
} }
pub(crate) struct DefaultSecureRng {
pub struct DefaultSecureRng {
rng: ChaCha8Rng, rng: ChaCha8Rng,
} }
@ -41,27 +77,30 @@ impl DefaultSecureRng {
let rng = ChaCha8Rng::from_entropy(); let rng = ChaCha8Rng::from_entropy();
DefaultSecureRng { rng } DefaultSecureRng { rng }
} }
}
impl RandomUniformDist<usize> for DefaultSecureRng {
type Parameters = usize;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut usize) {
*container = self.rng.gen_range(0..*parameters);
pub fn fill_bytes(&mut self, a: &mut [u8; 32]) {
self.rng.fill_bytes(a);
} }
} }
impl RandomUniformDist<[u8]> for DefaultSecureRng {
type Parameters = u8;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u8]) {
self.rng.fill_bytes(container);
impl NewWithSeed for DefaultSecureRng {
type Seed = <ChaCha8Rng as SeedableRng>::Seed;
fn new_with_seed(seed: Self::Seed) -> Self {
DefaultSecureRng::new_seeded(seed)
} }
} }
impl RandomUniformDist<[u32]> for DefaultSecureRng {
type Parameters = u32;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) {
impl<T, C> RandomFillUniformInModulus<[T], C> for DefaultSecureRng
where
T: PrimInt + SampleUniform,
C: Modulus<Element = T>,
{
fn random_fill(&mut self, modulus: &C, container: &mut [T]) {
izip!( izip!(
(&mut self.rng).sample_iter(Uniform::new(0, parameters)),
(&mut self.rng).sample_iter(Uniform::new_inclusive(
T::zero(),
modulus.largest_unsigned_value()
)),
container.iter_mut() container.iter_mut()
) )
.for_each(|(from, to)| { .for_each(|(from, to)| {
@ -70,99 +109,90 @@ impl RandomUniformDist<[u32]> for DefaultSecureRng {
} }
} }
impl RandomUniformDist<[u64]> for DefaultSecureRng {
type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) {
impl<T, C> RandomFillGaussianInModulus<[T], C> for DefaultSecureRng
where
T: PrimInt,
C: Modulus<Element = T>,
{
fn random_fill(&mut self, modulus: &C, container: &mut [T]) {
izip!( izip!(
(&mut self.rng).sample_iter(Uniform::new(0, parameters)),
rand_distr::Normal::new(0.0, 3.19f64)
.unwrap()
.sample_iter(&mut self.rng),
container.iter_mut() container.iter_mut()
) )
.for_each(|(from, to)| { .for_each(|(from, to)| {
*to = from;
*to = modulus.map_element_from_f64(from);
}); });
} }
} }
impl RandomGaussianDist<u64> for DefaultSecureRng {
type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) {
let o = rand_distr::Normal::new(0.0, 3.2f64)
.unwrap()
.sample(&mut self.rng)
.round();
// let o = 0.0f64;
let is_neg = o.is_sign_negative() && o != 0.0;
if is_neg {
*container = parameters - (o.abs() as u64);
} else {
*container = o as u64;
}
}
}
impl RandomGaussianDist<u32> for DefaultSecureRng {
type Parameters = u32;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u32) {
let o = rand_distr::Normal::new(0.0, 3.2f32)
.unwrap()
.sample(&mut self.rng)
.round();
// let o = 0.0f32;
let is_neg = o.is_sign_negative() && o != 0.0;
if is_neg {
*container = parameters - (o.abs() as u32);
} else {
*container = o as u32;
}
impl<T> RandomFill<[T]> for DefaultSecureRng
where
T: PrimInt + SampleUniform,
{
fn random_fill(&mut self, container: &mut [T]) {
izip!(
(&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())),
container.iter_mut()
)
.for_each(|(from, to)| {
*to = from;
});
} }
} }
impl RandomGaussianDist<[u64]> for DefaultSecureRng {
type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) {
impl<T> RandomFillGaussian<[T]> for DefaultSecureRng
where
T: FromPrimitive,
{
fn random_fill(&mut self, container: &mut [T]) {
izip!( izip!(
rand_distr::Normal::new(0.0, 3.2f64)
rand_distr::Normal::new(0.0, 3.19f64)
.unwrap() .unwrap()
.sample_iter(&mut self.rng), .sample_iter(&mut self.rng),
container.iter_mut() container.iter_mut()
) )
.for_each(|(oi, v)| {
let oi = oi.round();
let is_neg = oi.is_sign_negative() && oi != 0.0;
if is_neg {
*v = parameters - (oi.abs() as u64);
} else {
*v = oi as u64;
}
.for_each(|(from, to)| {
*to = T::from_f64(from).unwrap();
}); });
} }
} }
impl RandomGaussianDist<[u32]> for DefaultSecureRng {
type Parameters = u32;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) {
impl<T> RandomFill<[T; 32]> for DefaultSecureRng
where
T: PrimInt + SampleUniform,
{
fn random_fill(&mut self, container: &mut [T; 32]) {
izip!( izip!(
rand_distr::Normal::new(0.0, 3.2f32)
.unwrap()
.sample_iter(&mut self.rng),
(&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())),
container.iter_mut() container.iter_mut()
) )
.for_each(|(oi, v)| {
let oi = oi.round();
let is_neg = oi.is_sign_negative() && oi != 0.0;
if is_neg {
*v = parameters - (oi.abs() as u32);
} else {
*v = oi as u32;
}
.for_each(|(from, to)| {
*to = from;
}); });
} }
} }
impl<T> RandomElementInModulus<T, T> for DefaultSecureRng
where
T: Zero + SampleUniform,
{
fn random(&mut self, modulus: &T) -> T {
Uniform::new(T::zero(), modulus).sample(&mut self.rng)
}
}
impl<T, M: Modulus<Element = T>> RandomGaussianElementInModulus<T, M> for DefaultSecureRng {
fn random(&mut self, modulus: &M) -> T {
modulus.map_element_from_f64(
rand_distr::Normal::new(0.0, 3.19f64)
.unwrap()
.sample(&mut self.rng),
)
}
}
impl WithLocal for DefaultSecureRng { impl WithLocal for DefaultSecureRng {
fn with_local<F, R>(func: F) -> R fn with_local<F, R>(func: F) -> R
where where
@ -177,4 +207,11 @@ impl WithLocal for DefaultSecureRng {
{ {
DEFAULT_RNG.with_borrow_mut(|r| func(r)) DEFAULT_RNG.with_borrow_mut(|r| func(r))
} }
fn with_local_mut_mut<F, R>(func: &mut F) -> R
where
F: FnMut(&mut Self) -> R,
{
DEFAULT_RNG.with_borrow_mut(|r| func(r))
}
} }

+ 0
- 466
src/rgsw.rs

@ -1,466 +0,0 @@
use itertools::izip;
use crate::{
backend::VectorOps,
decomposer::{self, Decomposer},
ntt::Ntt,
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist},
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, RowMut, Secret,
};
struct RlweSecret {
values: Vec<i32>,
}
impl Secret for RlweSecret {
type Element = i32;
fn values(&self) -> &[Self::Element] {
&self.values
}
}
impl RlweSecret {
fn random(hw: usize, n: usize) -> RlweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
RlweSecret { values: out }
})
}
}
/// Encrypts message m as a RGSW ciphertext.
///
/// - m_eval: is `m` is evaluation domain
/// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 4,
/// ring_size). The matrix has the following structure [RLWE'_A(-sm) ||
/// RLWE'_B(-sm) || RLWE'_A(m) || RLWE'_B(m)]^T
fn encrypt_rgsw<
Mmut: MatrixMut + MatrixEntity,
M: Matrix<MatElement = Mmut::MatElement> + Clone,
S: Secret,
R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>
+ RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
>(
out_rgsw: &mut Mmut,
m_eval: &M,
gadget_vector: &[Mmut::MatElement],
s: &S,
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
{
let d = gadget_vector.len();
let q = mod_op.modulus();
let ring_size = s.values().len();
assert!(out_rgsw.dimension() == (d * 4, ring_size));
assert!(m_eval.dimension() == (1, ring_size));
// RLWE(-sm), RLWE(-sm)
let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row(d * 2);
let mut s_eval = Mmut::try_convert_from(s.values(), &q);
ntt_op.forward(s_eval.get_row_mut(0).as_mut());
let mut scratch_space = Mmut::zeros(1, ring_size);
// RLWE'(-sm)
let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d);
izip!(
a_rlwe_dash_nsm.iter_mut(),
b_rlwe_dash_nsm.iter_mut(),
gadget_vector.iter()
)
.for_each(|(ai, bi, beta_i)| {
// Sample a_i and transform to evaluation domain
RandomUniformDist::random_fill(rng, &q, ai.as_mut());
ntt_op.forward(ai.as_mut());
// a_i * s
mod_op.elwise_mul(
scratch_space.get_row_mut(0),
ai.as_ref(),
s_eval.get_row_slice(0),
);
// b_i = e_i + a_i * s
RandomGaussianDist::random_fill(rng, &q, bi.as_mut());
ntt_op.forward(bi.as_mut());
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0));
// a_i + \beta_i * m
mod_op.elwise_scalar_mul(
scratch_space.get_row_mut(0),
m_eval.get_row_slice(0),
beta_i,
);
mod_op.elwise_add_mut(ai.as_mut(), scratch_space.get_row_slice(0));
});
// RLWE(m)
let (a_rlwe_dash_m, b_rlwe_dash_m) = rlwe_dash_m.split_at_mut(d);
izip!(
a_rlwe_dash_m.iter_mut(),
b_rlwe_dash_m.iter_mut(),
gadget_vector.iter()
)
.for_each(|(ai, bi, beta_i)| {
// Sample e_i and transform to evaluation domain
RandomGaussianDist::random_fill(rng, &q, bi.as_mut());
ntt_op.forward(bi.as_mut());
// beta_i * m
mod_op.elwise_scalar_mul(
scratch_space.get_row_mut(0),
m_eval.get_row_slice(0),
beta_i,
);
// e_i + beta_i * m
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0));
// Sample a_i and transform to evaluation domain
RandomUniformDist::random_fill(rng, &q, ai.as_mut());
ntt_op.forward(ai.as_mut());
// ai * s
mod_op.elwise_mul(
scratch_space.get_row_mut(0),
ai.as_ref(),
s_eval.get_row_slice(0),
);
// b_i = a_i*s + e_i + beta_i*m
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0));
});
}
/// Returns RLWE(mm') = RLWE(m) x RGSW(m')
///
/// - rgsw_in: RGSW(m') in evaluation domain
/// - rlwe_in_decomposed: decomposed RLWE(m) in evaluation domain
/// - rlwe_out: returned RLWE(mm') in evaluation domain
fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain<
Mmut: MatrixMut + MatrixEntity,
M: Matrix<MatElement = Mmut::MatElement> + Clone,
ModOp: VectorOps<Element = Mmut::MatElement>,
>(
rgsw_in: &M,
rlwe_in_decomposed_eval: &Mmut,
rlwe_out_eval: &mut Mmut,
mod_op: &ModOp,
) where
<Mmut as Matrix>::R: RowMut,
{
let ring_size = rgsw_in.dimension().1;
let d_rgsw = rgsw_in.dimension().0 / 4;
assert!(rlwe_in_decomposed_eval.dimension() == (2 * d_rgsw, ring_size));
assert!(rlwe_out_eval.dimension() == (2, ring_size));
let (a_rlwe_out, b_rlwe_out) = rlwe_out_eval.split_at_row(1);
// a * RLWE'(-sm)
let a_rlwe_dash_nsm = rgsw_in.iter_rows().take(d_rgsw);
let b_rlwe_dash_nsm = rgsw_in.iter_rows().skip(d_rgsw).take(d_rgsw);
izip!(
rlwe_in_decomposed_eval.iter_rows().take(d_rgsw),
a_rlwe_dash_nsm
)
.for_each(|(a, b)| {
mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref());
});
izip!(
rlwe_in_decomposed_eval.iter_rows().take(d_rgsw),
b_rlwe_dash_nsm
)
.for_each(|(a, b)| {
mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref());
});
// b * RLWE'(m)
let a_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 2).take(d_rgsw);
let b_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 3);
izip!(
rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw),
a_rlwe_dash_m
)
.for_each(|(a, b)| {
mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref());
});
izip!(
rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw),
b_rlwe_dash_m
)
.for_each(|(a, b)| {
mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref());
});
}
fn decompose_rlwe<
M: Matrix + Clone,
Mmut: MatrixMut<MatElement = M::MatElement> + MatrixEntity,
D: Decomposer<Element = M::MatElement>,
>(
rlwe_in: &M,
decomposer: &D,
rlwe_in_decomposed: &mut Mmut,
) where
M::MatElement: Copy,
<Mmut as Matrix>::R: RowMut,
{
let d_rgsw = decomposer.d();
let ring_size = rlwe_in.dimension().1;
assert!(rlwe_in_decomposed.dimension() == (2 * d_rgsw, ring_size));
// Decompose rlwe_in
for ri in 0..ring_size {
// ai
let ai_decomposed = decomposer.decompose(rlwe_in.get(0, ri));
for j in 0..d_rgsw {
rlwe_in_decomposed.set(j, ri, ai_decomposed[j]);
}
// bi
let bi_decomposed = decomposer.decompose(rlwe_in.get(1, ri));
for j in 0..d_rgsw {
rlwe_in_decomposed.set(j + d_rgsw, ri, bi_decomposed[j]);
}
}
}
/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1)
///
/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain
/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain
/// - rlwe_out: is output RLWE(m0m1) with polynomials in coefficient domain
/// - rlwe_in_decomposed: is a matrix of dimension (d_rgsw * 2, ring_size) used
/// as scratch space to store decomposed RLWE(m0)
fn rlwe_by_rgsw<
M: Matrix + Clone,
Mmut: MatrixMut<MatElement = M::MatElement> + MatrixEntity,
D: Decomposer<Element = M::MatElement>,
ModOp: VectorOps<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,
>(
rlwe_in: &M,
rgsw_in: &M,
rlwe_out: &mut Mmut,
rlwe_in_decomposed: &mut Mmut,
decomposer: &D,
ntt_op: &NttOp,
mod_op: &ModOp,
) where
M::MatElement: Copy,
<Mmut as Matrix>::R: RowMut,
{
decompose_rlwe(rlwe_in, decomposer, rlwe_in_decomposed);
// transform rlwe_in decomposed to evaluation domain
rlwe_in_decomposed
.iter_rows_mut()
.for_each(|r| ntt_op.forward(r.as_mut()));
// decomposed RLWE x RGSW
rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain(
rgsw_in,
rlwe_in_decomposed,
rlwe_out,
mod_op,
);
// transform rlwe_out to coefficient domain
rlwe_out
.iter_rows_mut()
.for_each(|r| ntt_op.backward(r.as_mut()));
}
/// Encrypt polynomial m(X) as RLWE ciphertext.
///
/// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE
/// ciphertext is a matirx with first row consiting polynomial `a` and the
/// second rows consting polynomial `b`
fn encrypt_rlwe<
Mmut: Matrix + MatrixMut + Clone,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S: Secret,
R: RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>
+ RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>,
>(
m: &Mmut,
rlwe_out: &mut Mmut,
s: &S,
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
{
let ring_size = s.values().len();
assert!(rlwe_out.dimension() == (2, ring_size));
assert!(m.dimension() == (1, ring_size));
let q = mod_op.modulus();
// sample a
RandomUniformDist::random_fill(rng, &q, rlwe_out.get_row_mut(0));
// s * a
let mut sa = Mmut::try_convert_from(s.values(), &q);
ntt_op.forward(sa.get_row_mut(0));
ntt_op.forward(rlwe_out.get_row_mut(0));
mod_op.elwise_mul_mut(sa.get_row_mut(0), rlwe_out.get_row_slice(0));
ntt_op.backward(rlwe_out.get_row_mut(0));
ntt_op.backward(sa.get_row_mut(0));
// sample e
RandomGaussianDist::random_fill(rng, &q, rlwe_out.get_row_mut(1));
mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m.get_row_slice(0));
mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), sa.get_row_slice(0));
}
/// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m
///
/// - rlwe_ct: input degree 1 ciphertext RLWE(m).
fn decrypt_rlwe<
Mmut: MatrixMut + Clone,
M: Matrix<MatElement = Mmut::MatElement>,
ModOp: VectorOps<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S: Secret,
>(
rlwe_ct: &M,
s: &S,
m_out: &mut Mmut,
ntt_op: &NttOp,
mod_op: &ModOp,
) where
<Mmut as Matrix>::R: RowMut,
Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>,
Mmut::MatElement: Copy,
{
let ring_size = s.values().len();
assert!(rlwe_ct.dimension() == (2, ring_size));
assert!(m_out.dimension() == (1, ring_size));
// transform a to evluation form
m_out
.get_row_mut(0)
.copy_from_slice(rlwe_ct.get_row_slice(0));
ntt_op.forward(m_out.get_row_mut(0));
// -s*a
let mut s = Mmut::try_convert_from(&s.values(), &mod_op.modulus());
ntt_op.forward(s.get_row_mut(0));
mod_op.elwise_mul_mut(m_out.get_row_mut(0), s.get_row_slice(0));
mod_op.elwise_neg_mut(m_out.get_row_mut(0));
ntt_op.backward(m_out.get_row_mut(0));
// m+e = b - s*a
mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1));
}
#[cfg(test)]
mod tests {
use std::vec;
use itertools::Itertools;
use rand::{thread_rng, Rng};
use crate::{
backend::ModularOpsU64,
decomposer::{gadget_vector, DefaultDecomposer},
ntt::{self, Ntt, NttBackendU64},
random::{DefaultSecureRng, RandomUniformDist},
utils::{generate_prime, negacyclic_mul},
};
use super::{decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, rlwe_by_rgsw, RlweSecret};
#[test]
fn rlwe_by_rgsw_works() {
let logq = 50;
let logp = 3;
let ring_size = 1 << 10;
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
let p = 1u64 << logp;
let d_rgsw = 10;
let logb = 5;
let mut rng = DefaultSecureRng::new();
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let mut m0 = vec![0u64; ring_size as usize];
RandomUniformDist::<[u64]>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice());
let mut m1 = vec![0u64; ring_size as usize];
m1[thread_rng().gen_range(0..ring_size) as usize] = 1;
let ntt_op = NttBackendU64::new(q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
// Encrypt m1 as RGSW(m1)
let mut rgsw_ct = vec![vec![0u64; ring_size as usize]; d_rgsw * 4];
let gadget_vector = gadget_vector(logq, logb, d_rgsw);
let mut m1_eval = m1.clone();
ntt_op.forward(&mut m1_eval);
encrypt_rgsw(
&mut rgsw_ct,
&vec![m1_eval],
&gadget_vector,
&s,
&mod_op,
&ntt_op,
&mut rng,
);
// println!("RGSW(m1): {:?}", &rgsw_ct);
// Encrypt m0 as RLWE(m0)
let mut rlwe_in_ct = vec![vec![0u64; ring_size as usize]; 2];
let encoded_m = m0
.iter()
.map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64)
.collect_vec();
encrypt_rlwe(
&vec![encoded_m.clone()],
&mut rlwe_in_ct,
&s,
&mod_op,
&ntt_op,
&mut rng,
);
// RLWE(m0m1) = RLWE(m0) x RGSW(m1)
let mut rlwe_out_ct = vec![vec![0u64; ring_size as usize]; 2];
let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw * 2];
let decomposer = DefaultDecomposer::new(q, logb, d_rgsw);
rlwe_by_rgsw(
&rlwe_in_ct,
&rgsw_ct,
&mut rlwe_out_ct,
&mut scratch_space,
&decomposer,
&ntt_op,
&mod_op,
);
// Decrypt RLWE(m0m1)
let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]];
decrypt_rlwe(&rlwe_out_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op);
let m0m1_back = encoded_m0m1_back[0]
.iter()
.map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p)
.collect_vec();
let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p;
let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p);
assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back);
// dbg!(&m0m1_back, m0m1, q);
}
}

+ 677
- 0
src/rgsw/keygen.rs

@ -0,0 +1,677 @@
use std::{fmt::Debug, ops::Sub};
use itertools::izip;
use num_traits::{PrimInt, Signed, ToPrimitive, Zero};
use crate::{
backend::{ArithmeticOps, GetModulus, Modulus, VectorOps},
ntt::Ntt,
random::{
RandomElementInModulus, RandomFill, RandomFillGaussianInModulus, RandomFillUniformInModulus,
},
utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom1},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
pub(crate) fn generate_auto_map(ring_size: usize, k: isize) -> (Vec<usize>, Vec<bool>) {
assert!(k & 1 == 1, "Auto {k} must be odd");
let k = if k < 0 {
// k is -ve, return k%(2*N)
(2 * ring_size) - (k.abs() as usize % (2 * ring_size))
} else {
k as usize
};
let (auto_map_index, auto_sign_index): (Vec<usize>, Vec<bool>) = (0..ring_size)
.into_iter()
.map(|i| {
let mut to_index = (i * k) % (2 * ring_size);
let mut sign = true;
// wrap around. false implies negative
if to_index >= ring_size {
to_index = to_index - ring_size;
sign = false;
}
(to_index, sign)
})
.unzip();
(auto_map_index, auto_sign_index)
}
/// Returns RGSW(m)
///
/// RGSW = [RLWE'(-sm) || RLWE(m)] = [RLWE'_A(-sm), RLWE'_B(-sm), RLWE'_A(m),
/// RLWE'_B(m)]
///
/// RGSW(m1) ciphertext is used for RLWE(m0) x RGSW(m1) multiplication.
/// Let RLWE(m) = [a, b] where b = as + e + m0.
/// For RLWExRGSW we calculate:
/// (\sum signed_decompose(a)[i] x RLWE(-s \beta^i' m1))
/// + (\sum signed_decompose(b)[i'] x RLWE(\beta^i' m1))
/// = RLWE(m0m1)
/// We denote decomposer count for signed_decompose(a)[i] with d_a and
/// corresponding gadget vector with `gadget_a`. We denote decomposer count for
/// signed_decompose(b)[i] with d_b and corresponding gadget vector with
/// `gadget_b`
///
/// In secret key RGSW encrypton RLWE'_A(m) can be seeded. Hence, we seed it
/// using the `p_rng` passed and the retured RGSW ciphertext has d_a * 2 + d_b
/// rows
///
/// - s: is the secret key
/// - m: message to encrypt
/// - gadget_a: Gadget vector for RLWE'(-sm)
/// - gadget_b: Gadget vector for RLWE'(m)
/// - p_rng: Seeded psuedo random generator used to sample RLWE'_A(m).
pub(crate) fn secret_key_encrypt_rgsw<
Mmut: MatrixMut + MatrixEntity,
S,
R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>
+ RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>,
PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>,
ModOp: VectorOps<Element = Mmut::MatElement> + GetModulus<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
>(
out_rgsw: &mut Mmut,
m: &[Mmut::MatElement],
gadget_a: &[Mmut::MatElement],
gadget_b: &[Mmut::MatElement],
s: &[S],
mod_op: &ModOp,
ntt_op: &NttOp,
p_rng: &mut PR,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut + RowEntity + TryConvertFrom1<[S], ModOp::M> + Debug,
Mmut::MatElement: Copy + Debug,
{
let d_a = gadget_a.len();
let d_b = gadget_b.len();
let q = mod_op.modulus();
let ring_size = s.len();
assert!(out_rgsw.dimension() == (d_a * 2 + d_b, ring_size));
assert!(m.as_ref().len() == ring_size);
// RLWE(-sm), RLWE(m)
let (rlwe_dash_nsm, b_rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2);
let mut s_eval = Mmut::R::try_convert_from(s, &q);
ntt_op.forward(s_eval.as_mut());
let mut scratch_space = Mmut::R::zeros(ring_size);
// RLWE'(-sm)
let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d_a);
izip!(
a_rlwe_dash_nsm.iter_mut(),
b_rlwe_dash_nsm.iter_mut(),
gadget_a.iter()
)
.for_each(|(ai, bi, beta_i)| {
// Sample a_i
RandomFillUniformInModulus::random_fill(rng, &q, ai.as_mut());
// a_i * s
scratch_space.as_mut().copy_from_slice(ai.as_ref());
ntt_op.forward(scratch_space.as_mut());
mod_op.elwise_mul_mut(scratch_space.as_mut(), s_eval.as_ref());
ntt_op.backward(scratch_space.as_mut());
// b_i = e_i + a_i * s
RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut());
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref());
// a_i + \beta_i * m
mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i);
mod_op.elwise_add_mut(ai.as_mut(), scratch_space.as_ref());
});
// RLWE(m)
let mut a_rlwe_dash_m = {
// polynomials of part A of RLWE'(m) are sampled from seed
let mut a = Mmut::zeros(d_b, ring_size);
a.iter_rows_mut()
.for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, &q, ai.as_mut()));
a
};
izip!(
a_rlwe_dash_m.iter_rows_mut(),
b_rlwe_dash_m.iter_mut(),
gadget_b.iter()
)
.for_each(|(ai, bi, beta_i)| {
// ai * s
ntt_op.forward(ai.as_mut());
mod_op.elwise_mul_mut(ai.as_mut(), s_eval.as_ref());
ntt_op.backward(ai.as_mut());
// beta_i * m
mod_op.elwise_scalar_mul(scratch_space.as_mut(), m.as_ref(), beta_i);
// Sample e_i
RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut());
// e_i + beta_i * m + ai*s
mod_op.elwise_add_mut(bi.as_mut(), scratch_space.as_ref());
mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref());
});
}
/// Returns RGSW(m) encrypted with public key
///
/// Follows the same routine as `secret_key_encrypt_rgsw` but with the
/// difference that each RLWE encryption uses public key instead of secret key.
///
/// Since public key encryption cannot be seeded `RLWE'_A(m)` is included in the
/// ciphertext. Hence the returned RGSW ciphertext has d_a * 2 + d_b * 2 rows
pub(crate) fn public_key_encrypt_rgsw<
Mmut: MatrixMut + MatrixEntity,
M: Matrix<MatElement = Mmut::MatElement>,
R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>
+ RandomFill<[u8]>
+ RandomElementInModulus<usize, usize>,
ModOp: VectorOps<Element = Mmut::MatElement> + GetModulus<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
>(
out_rgsw: &mut Mmut,
m: &[M::MatElement],
public_key: &M,
gadget_a: &[Mmut::MatElement],
gadget_b: &[Mmut::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut + RowEntity + TryConvertFrom1<[i32], ModOp::M>,
Mmut::MatElement: Copy,
{
let ring_size = public_key.dimension().1;
let d_a = gadget_a.len();
let d_b = gadget_b.len();
assert!(public_key.dimension().0 == 2);
assert!(out_rgsw.dimension() == (d_a * 2 + d_b * 2, ring_size));
let mut pk_eval = Mmut::zeros(2, ring_size);
izip!(pk_eval.iter_rows_mut(), public_key.iter_rows()).for_each(|(to_i, from_i)| {
to_i.as_mut().copy_from_slice(from_i.as_ref());
ntt_op.forward(to_i.as_mut());
});
let p0 = pk_eval.get_row_slice(0);
let p1 = pk_eval.get_row_slice(1);
let q = mod_op.modulus();
// RGSW(m) = RLWE'(-sm), RLWE(m)
let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row_mut(d_a * 2);
// RLWE(-sm)
let (rlwe_dash_nsm_parta, rlwe_dash_nsm_partb) = rlwe_dash_nsm.split_at_mut(d_a);
izip!(
rlwe_dash_nsm_parta.iter_mut(),
rlwe_dash_nsm_partb.iter_mut(),
gadget_a.iter()
)
.for_each(|(ai, bi, beta_i)| {
// sample ephemeral secret u_i
let mut u = vec![0i32; ring_size];
fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng);
let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q);
ntt_op.forward(u_eval.as_mut());
let mut u_eval_copy = Mmut::R::zeros(ring_size);
u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref());
// p0 * u
mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref());
// p1 * u
mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref());
ntt_op.backward(u_eval.as_mut());
ntt_op.backward(u_eval_copy.as_mut());
// sample error
RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut());
RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut());
// a = p0*u+e0
mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref());
// b = p1*u+e1
mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref());
// a = p0*u + e0 + \beta*m
// use u_eval as scratch
mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i);
mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref());
});
// RLWE(m)
let (rlwe_dash_m_parta, rlwe_dash_m_partb) = rlwe_dash_m.split_at_mut(d_b);
izip!(
rlwe_dash_m_parta.iter_mut(),
rlwe_dash_m_partb.iter_mut(),
gadget_b.iter()
)
.for_each(|(ai, bi, beta_i)| {
// sample ephemeral secret u_i
let mut u = vec![0i32; ring_size];
fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng);
let mut u_eval = Mmut::R::try_convert_from(u.as_ref(), &q);
ntt_op.forward(u_eval.as_mut());
let mut u_eval_copy = Mmut::R::zeros(ring_size);
u_eval_copy.as_mut().copy_from_slice(u_eval.as_ref());
// p0 * u
mod_op.elwise_mul_mut(u_eval.as_mut(), p0.as_ref());
// p1 * u
mod_op.elwise_mul_mut(u_eval_copy.as_mut(), p1.as_ref());
ntt_op.backward(u_eval.as_mut());
ntt_op.backward(u_eval_copy.as_mut());
// sample error
RandomFillGaussianInModulus::random_fill(rng, &q, ai.as_mut());
RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut());
// a = p0*u+e0
mod_op.elwise_add_mut(ai.as_mut(), u_eval.as_ref());
// b = p1*u+e1
mod_op.elwise_add_mut(bi.as_mut(), u_eval_copy.as_ref());
// b = p1*u + e0 + \beta*m
// use u_eval as scratch
mod_op.elwise_scalar_mul(u_eval.as_mut(), m.as_ref(), beta_i);
mod_op.elwise_add_mut(bi.as_mut(), u_eval.as_ref());
});
}
/// Returns key switching key to key switch ciphertext RLWE_{from_s}(m)
/// to RLWE_{to_s}(m).
///
/// Let key switching decomposer have `d` decompostion count with gadget vector:
/// [1, \beta, ..., \beta^d-1]
///
/// Key switching key consists of `d` RLWE ciphertexts:
/// RLWE'_{to_s}(-from_s) = [RLWE_{to_s}(\beta^i -from_s)]
///
/// In RLWE(m) s.t. b = as + e + m where s is the secret key, `a` can be seeded.
/// And we seed all RLWE ciphertexts in key switchin key.
///
/// - neg_from_s: Negative of secret polynomial to key switch from (i.e.
/// -from_s)
/// - to_s: secret polynomial to key switch to.
/// - gadget_vector: Gadget vector of decomposer used in key switch
/// - p_rng: Seeded pseudo random generate used to generate `a` polynomials of
/// key switching key RLWE ciphertexts
fn seeded_rlwe_ksk_gen<
Mmut: MatrixMut + MatrixEntity,
ModOp: ArithmeticOps<Element = Mmut::MatElement>
+ VectorOps<Element = Mmut::MatElement>
+ GetModulus<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>,
PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>,
>(
ksk_out: &mut Mmut,
neg_from_s: Mmut::R,
mut to_s: Mmut::R,
gadget_vector: &[Mmut::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
p_rng: &mut PR,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
{
let ring_size = neg_from_s.as_ref().len();
let d = gadget_vector.len();
assert!(ksk_out.dimension() == (d, ring_size));
let q = mod_op.modulus();
ntt_op.forward(to_s.as_mut());
// RLWE'_{to_s}(-from_s)
let mut part_a = {
let mut a = Mmut::zeros(d, ring_size);
a.iter_rows_mut()
.for_each(|ai| RandomFillUniformInModulus::random_fill(p_rng, q, ai.as_mut()));
a
};
izip!(
part_a.iter_rows_mut(),
ksk_out.iter_rows_mut(),
gadget_vector.iter(),
)
.for_each(|(ai, bi, beta_i)| {
// si * ai
ntt_op.forward(ai.as_mut());
mod_op.elwise_mul_mut(ai.as_mut(), to_s.as_ref());
ntt_op.backward(ai.as_mut());
// ei + to_s*ai
RandomFillGaussianInModulus::random_fill(rng, &q, bi.as_mut());
mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref());
// beta_i * -from_s
// use ai as scratch space
mod_op.elwise_scalar_mul(ai.as_mut(), neg_from_s.as_ref(), beta_i);
// bi = ei + to_s*ai + beta_i*-from_s
mod_op.elwise_add_mut(bi.as_mut(), ai.as_ref());
});
}
/// Returns auto key to send RLWE(m(X)) -> RLWE(m(X^k))
///
/// Auto key is key switchin key that key-switches RLWE_{s(X^k)}(m(X^k)) to
/// RLWE_{s(X)}(m(X^k)).
///
/// - s: secret polynomial s(X)
/// - auto_k: k used in for autmorphism X -> X^k
/// - gadget_vector: Gadget vector corresponding to decomposer used in key
/// switch
/// - p_rng: pseudo random generator used to generate `a` polynomials of key
/// switching key RLWE ciphertexts
pub(crate) fn seeded_auto_key_gen<
Mmut: MatrixMut + MatrixEntity,
ModOp: ArithmeticOps<Element = Mmut::MatElement>
+ VectorOps<Element = Mmut::MatElement>
+ GetModulus<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S,
R: RandomFillGaussianInModulus<[Mmut::MatElement], ModOp::M>,
PR: RandomFillUniformInModulus<[Mmut::MatElement], ModOp::M>,
>(
ksk_out: &mut Mmut,
s: &[S],
auto_k: isize,
gadget_vector: &[Mmut::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
p_rng: &mut PR,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut,
Mmut::R: TryConvertFrom1<[S], ModOp::M> + RowEntity,
Mmut::MatElement: Copy + Sub<Output = Mmut::MatElement>,
{
let ring_size = s.len();
let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k);
let q = mod_op.modulus();
// s(X) -> -s(X^k)
let s = Mmut::R::try_convert_from(s, q);
let mut neg_s_auto = Mmut::R::zeros(s.as_ref().len());
izip!(s.as_ref(), auto_map_index.iter(), auto_map_sign.iter()).for_each(
|(el, to_index, sign)| {
// if sign is +ve (true), then negate because we need -s(X) (i.e. do the
// opposite than the usual case)
if *sign {
neg_s_auto.as_mut()[*to_index] = mod_op.neg(el);
} else {
neg_s_auto.as_mut()[*to_index] = *el;
}
},
);
// Ksk from -s(X^k) to s(X)
seeded_rlwe_ksk_gen(
ksk_out,
neg_s_auto,
s,
gadget_vector,
mod_op,
ntt_op,
p_rng,
rng,
);
}
/// Returns seeded RLWE(m(X))
///
/// RLWE(m(X)) = [a(X), b(X) = a(X)s(X) + e(X) + m(X)]
///
/// a(X) of RLWE encyrptions using secret key s(X) can be seeded. We use seeded
/// pseudo random generator `p_rng` to sample a(X) and return seeded RLWE
/// ciphertext (i.e. only b(X))
pub(crate) fn seeded_secret_key_encrypt_rlwe<
Ro: Row + RowMut + RowEntity,
ModOp: VectorOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
NttOp: Ntt<Element = Ro::Element>,
S,
R: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>,
PR: RandomFillUniformInModulus<[Ro::Element], ModOp::M>,
>(
m: &[Ro::Element],
b_rlwe_out: &mut Ro,
s: &[S],
mod_op: &ModOp,
ntt_op: &NttOp,
p_rng: &mut PR,
rng: &mut R,
) where
Ro: TryConvertFrom1<[S], ModOp::M> + Debug,
{
let ring_size = s.len();
assert!(m.as_ref().len() == ring_size);
assert!(b_rlwe_out.as_ref().len() == ring_size);
let q = mod_op.modulus();
// sample a
let mut a = {
let mut a = Ro::zeros(ring_size);
RandomFillUniformInModulus::random_fill(p_rng, q, a.as_mut());
a
};
// s * a
let mut sa = Ro::try_convert_from(s, q);
ntt_op.forward(sa.as_mut());
ntt_op.forward(a.as_mut());
mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref());
ntt_op.backward(sa.as_mut());
// sample e
RandomFillGaussianInModulus::random_fill(rng, q, b_rlwe_out.as_mut());
mod_op.elwise_add_mut(b_rlwe_out.as_mut(), m.as_ref());
mod_op.elwise_add_mut(b_rlwe_out.as_mut(), sa.as_ref());
}
/// Returns RLWE(m(X)) encrypted using public key.
///
/// Unlike secret key encryption, public key encryption cannot be seeded
pub(crate) fn public_key_encrypt_rlwe<
M: Matrix,
Mmut: MatrixMut<MatElement = M::MatElement>,
ModOp: VectorOps<Element = M::MatElement> + GetModulus<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>,
S,
R: RandomFillGaussianInModulus<[M::MatElement], ModOp::M>
+ RandomFillUniformInModulus<[M::MatElement], ModOp::M>
+ RandomFill<[u8]>
+ RandomElementInModulus<usize, usize>,
>(
rlwe_out: &mut Mmut,
pk: &M,
m: &[M::MatElement],
mod_op: &ModOp,
ntt_op: &NttOp,
rng: &mut R,
) where
<Mmut as Matrix>::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity,
M::MatElement: Copy,
S: Zero + Signed + Copy,
{
let ring_size = m.len();
assert!(rlwe_out.dimension() == (2, ring_size));
let q = mod_op.modulus();
let mut u = vec![S::zero(); ring_size];
fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng);
let mut u = Mmut::R::try_convert_from(&u, q);
ntt_op.forward(u.as_mut());
let mut ua = Mmut::R::zeros(ring_size);
ua.as_mut().copy_from_slice(pk.get_row_slice(0));
let mut ub = Mmut::R::zeros(ring_size);
ub.as_mut().copy_from_slice(pk.get_row_slice(1));
// a*u
ntt_op.forward(ua.as_mut());
mod_op.elwise_mul_mut(ua.as_mut(), u.as_ref());
ntt_op.backward(ua.as_mut());
// b*u
ntt_op.forward(ub.as_mut());
mod_op.elwise_mul_mut(ub.as_mut(), u.as_ref());
ntt_op.backward(ub.as_mut());
// sample error
rlwe_out.iter_rows_mut().for_each(|ri| {
RandomFillGaussianInModulus::random_fill(rng, &q, ri.as_mut());
});
// a*u + e0
mod_op.elwise_add_mut(rlwe_out.get_row_mut(0), ua.as_ref());
// b*u + e1
mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), ub.as_ref());
// b*u + e1 + m
mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m);
}
/// Returns RLWE public key generated using RLWE secret key
pub(crate) fn rlwe_public_key<
Ro: RowMut + RowEntity,
S,
ModOp: VectorOps<Element = Ro::Element> + GetModulus<Element = Ro::Element>,
NttOp: Ntt<Element = Ro::Element>,
PRng: RandomFillUniformInModulus<[Ro::Element], ModOp::M>,
Rng: RandomFillGaussianInModulus<[Ro::Element], ModOp::M>,
>(
part_b_out: &mut Ro,
s: &[S],
ntt_op: &NttOp,
mod_op: &ModOp,
p_rng: &mut PRng,
rng: &mut Rng,
) where
Ro: TryConvertFrom1<[S], ModOp::M>,
{
let ring_size = s.len();
assert!(part_b_out.as_ref().len() == ring_size);
let q = mod_op.modulus();
// sample a
let mut a = {
let mut tmp = Ro::zeros(ring_size);
RandomFillUniformInModulus::random_fill(p_rng, &q, tmp.as_mut());
tmp
};
ntt_op.forward(a.as_mut());
// s*a
let mut sa = Ro::try_convert_from(s, &q);
ntt_op.forward(sa.as_mut());
mod_op.elwise_mul_mut(sa.as_mut(), a.as_ref());
ntt_op.backward(sa.as_mut());
// s*a + e
RandomFillGaussianInModulus::random_fill(rng, &q, part_b_out.as_mut());
mod_op.elwise_add_mut(part_b_out.as_mut(), sa.as_ref());
}
/// Decrypts ciphertext RLWE(m) and returns noisy m
///
/// We assume RLWE(m) = [a, b] is a degree 1 ciphertext s.t. b - sa = e + m
pub(crate) fn decrypt_rlwe<
R: RowMut,
M: Matrix<MatElement = R::Element>,
ModOp: VectorOps<Element = R::Element> + GetModulus<Element = R::Element>,
NttOp: Ntt<Element = R::Element>,
S,
>(
rlwe_ct: &M,
s: &[S],
m_out: &mut R,
ntt_op: &NttOp,
mod_op: &ModOp,
) where
R: TryConvertFrom1<[S], ModOp::M>,
R::Element: Copy,
{
let ring_size = s.len();
assert!(rlwe_ct.dimension() == (2, ring_size));
assert!(m_out.as_ref().len() == ring_size);
// transform a to evluation form
m_out.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0));
ntt_op.forward(m_out.as_mut());
// -s*a
let mut s = R::try_convert_from(&s, mod_op.modulus());
ntt_op.forward(s.as_mut());
mod_op.elwise_mul_mut(m_out.as_mut(), s.as_ref());
mod_op.elwise_neg_mut(m_out.as_mut());
ntt_op.backward(m_out.as_mut());
// m+e = b - s*a
mod_op.elwise_add_mut(m_out.as_mut(), rlwe_ct.get_row_slice(1));
}
// Measures maximum noise in degree 1 RLWE ciphertext against message `want_m`
fn measure_max_noise<
Mmut: MatrixMut + Matrix,
ModOp: VectorOps<Element = Mmut::MatElement> + GetModulus<Element = Mmut::MatElement>,
NttOp: Ntt<Element = Mmut::MatElement>,
S,
>(
rlwe_ct: &Mmut,
want_m: &Mmut::R,
ntt_op: &NttOp,
mod_op: &ModOp,
s: &[S],
) -> f64
where
<Mmut as Matrix>::R: RowMut,
Mmut::R: RowEntity + TryConvertFrom1<[S], ModOp::M>,
Mmut::MatElement: PrimInt + ToPrimitive + Debug,
{
let ring_size = s.len();
assert!(rlwe_ct.dimension() == (2, ring_size));
assert!(want_m.as_ref().len() == ring_size);
// -(s * a)
let q = mod_op.modulus();
let mut s = Mmut::R::try_convert_from(s, &q);
ntt_op.forward(s.as_mut());
let mut a = Mmut::R::zeros(ring_size);
a.as_mut().copy_from_slice(rlwe_ct.get_row_slice(0));
ntt_op.forward(a.as_mut());
mod_op.elwise_mul_mut(s.as_mut(), a.as_ref());
mod_op.elwise_neg_mut(s.as_mut());
ntt_op.backward(s.as_mut());
// m+e = b - s*a
let mut m_plus_e = s;
mod_op.elwise_add_mut(m_plus_e.as_mut(), rlwe_ct.get_row_slice(1));
// difference
mod_op.elwise_sub_mut(m_plus_e.as_mut(), want_m.as_ref());
let mut max_diff_bits = f64::MIN;
m_plus_e.as_ref().iter().for_each(|v| {
let bits = (q.map_element_to_i64(v).to_f64().unwrap().abs()).log2();
if max_diff_bits < bits {
max_diff_bits = bits;
}
});
return max_diff_bits;
}

+ 982
- 0
src/rgsw/mod.rs

@ -0,0 +1,982 @@
mod keygen;
mod runtime;
pub(crate) use keygen::*;
pub(crate) use runtime::*;
#[cfg(test)]
pub(crate) mod tests {
use std::{fmt::Debug, marker::PhantomData, vec};
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng};
use crate::{
backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps},
decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer},
ntt::{Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus},
rgsw::{
rlwe_auto_scratch_rows, rlwe_auto_shoup, rlwe_by_rgsw_shoup, rlwe_x_rgsw_scratch_rows,
RgswCiphertextRef, RlweCiphertextMutRef, RlweKskRef, RuntimeScratchMutRef,
},
utils::{
fill_random_ternary_secret_with_hamming_weight, generate_prime, negacyclic_mul,
tests::Stats, ToShoup, TryConvertFrom1, WithLocal,
},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut,
};
use super::{
keygen::{
decrypt_rlwe, generate_auto_map, rlwe_public_key, secret_key_encrypt_rgsw,
seeded_auto_key_gen, seeded_secret_key_encrypt_rlwe,
},
rgsw_x_rgsw_scratch_rows,
runtime::{rgsw_by_rgsw_inplace, rlwe_auto, rlwe_by_rgsw},
RgswCiphertextMutRef,
};
struct SeededAutoKey<M, S, Mod>
where
M: Matrix,
{
data: M,
seed: S,
modulus: Mod,
}
impl<M: Matrix + MatrixEntity, S, Mod: Modulus<Element = M::MatElement>> SeededAutoKey<M, S, Mod> {
fn empty<D: Decomposer>(
ring_size: usize,
auto_decomposer: &D,
seed: S,
modulus: Mod,
) -> Self {
SeededAutoKey {
data: M::zeros(auto_decomposer.decomposition_count().0, ring_size),
seed,
modulus,
}
}
}
struct AutoKeyEvaluationDomain<M: Matrix, R, N> {
data: M,
_phantom: PhantomData<(R, N)>,
}
impl<
M: MatrixMut + MatrixEntity,
Mod: Modulus<Element = M::MatElement> + Clone,
R: RandomFillUniformInModulus<[M::MatElement], Mod> + NewWithSeed,
N: NttInit<Mod> + Ntt<Element = M::MatElement>,
> From<&SeededAutoKey<M, R::Seed, Mod>> for AutoKeyEvaluationDomain<M, R, N>
where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
R::Seed: Clone,
{
fn from(value: &SeededAutoKey<M, R::Seed, Mod>) -> Self {
let (d, ring_size) = value.data.dimension();
let mut data = M::zeros(2 * d, ring_size);
// sample RLWE'_A(-s(X^k))
let mut p_rng = R::new_with_seed(value.seed.clone());
data.iter_rows_mut().take(d).for_each(|r| {
RandomFillUniformInModulus::random_fill(&mut p_rng, &value.modulus, r.as_mut());
});
// copy over RLWE'_B(-s(X^k))
izip!(data.iter_rows_mut().skip(d), value.data.iter_rows()).for_each(
|(to_r, from_r)| {
to_r.as_mut().copy_from_slice(from_r.as_ref());
},
);
// send RLWE'(-s(X^k)) polynomials to evaluation domain
let ntt_op = N::new(&value.modulus, ring_size);
data.iter_rows_mut()
.for_each(|r| ntt_op.forward(r.as_mut()));
AutoKeyEvaluationDomain {
data,
_phantom: PhantomData,
}
}
}
struct RgswCiphertext<M: Matrix, Mod> {
/// Rgsw ciphertext polynomials
data: M,
modulus: Mod,
/// Decomposition for RLWE part A
d_a: usize,
/// Decomposition for RLWE part B
d_b: usize,
}
impl<M: MatrixEntity, Mod: Modulus<Element = M::MatElement>> RgswCiphertext<M, Mod> {
pub(crate) fn empty<D: RlweDecomposer>(
ring_size: usize,
decomposer: &D,
modulus: Mod,
) -> RgswCiphertext<M, Mod> {
RgswCiphertext {
data: M::zeros(
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0 * 2,
ring_size,
),
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
modulus,
}
}
}
pub struct SeededRgswCiphertext<M, S, Mod>
where
M: Matrix,
{
pub(crate) data: M,
seed: S,
modulus: Mod,
/// Decomposition for RLWE part A
d_a: usize,
/// Decomposition for RLWE part B
d_b: usize,
}
impl<M: Matrix + MatrixEntity, S, Mod> SeededRgswCiphertext<M, S, Mod> {
pub(crate) fn empty<D: RlweDecomposer>(
ring_size: usize,
decomposer: &D,
seed: S,
modulus: Mod,
) -> SeededRgswCiphertext<M, S, Mod> {
SeededRgswCiphertext {
data: M::zeros(
decomposer.a().decomposition_count().0 * 2
+ decomposer.b().decomposition_count().0,
ring_size,
),
seed,
modulus,
d_a: decomposer.a().decomposition_count().0,
d_b: decomposer.b().decomposition_count().0,
}
}
}
impl<M: Debug + Matrix, S: Debug, Mod: Debug> Debug for SeededRgswCiphertext<M, S, Mod>
where
M::MatElement: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SeededRgswCiphertext")
.field("data", &self.data)
.field("seed", &self.seed)
.field("modulus", &self.modulus)
.finish()
}
}
pub struct RgswCiphertextEvaluationDomain<M, Mod, R, N> {
pub(crate) data: M,
modulus: Mod,
_phantom: PhantomData<(R, N)>,
}
impl<
M: MatrixMut + MatrixEntity,
Mod: Modulus<Element = M::MatElement> + Clone,
R: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>,
N: NttInit<Mod> + Ntt<Element = M::MatElement> + Debug,
> From<&SeededRgswCiphertext<M, R::Seed, Mod>>
for RgswCiphertextEvaluationDomain<M, Mod, R, N>
where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
R::Seed: Clone,
M: Debug,
{
fn from(value: &SeededRgswCiphertext<M, R::Seed, Mod>) -> Self {
let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1);
// copy RLWE'(-sm)
izip!(
data.iter_rows_mut().take(value.d_a * 2),
value.data.iter_rows().take(value.d_a * 2)
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// sample A polynomials of RLWE'(m) - RLWE'A(m)
let mut p_rng = R::new_with_seed(value.seed.clone());
izip!(data.iter_rows_mut().skip(value.d_a * 2).take(value.d_b * 1))
.for_each(|ri| p_rng.random_fill(&value.modulus, ri.as_mut()));
// RLWE'_B(m)
izip!(
data.iter_rows_mut().skip(value.d_a * 2 + value.d_b),
value.data.iter_rows().skip(value.d_a * 2)
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// Send polynomials to evaluation domain
let ring_size = data.dimension().1;
let nttop = N::new(&value.modulus, ring_size);
data.iter_rows_mut()
.for_each(|ri| nttop.forward(ri.as_mut()));
Self {
data: data,
modulus: value.modulus.clone(),
_phantom: PhantomData,
}
}
}
impl<
M: MatrixMut + MatrixEntity,
Mod: Modulus<Element = M::MatElement> + Clone,
R,
N: NttInit<Mod> + Ntt<Element = M::MatElement>,
> From<&RgswCiphertext<M, Mod>> for RgswCiphertextEvaluationDomain<M, Mod, R, N>
where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
M: Debug,
{
fn from(value: &RgswCiphertext<M, Mod>) -> Self {
let mut data = M::zeros(value.d_a * 2 + value.d_b * 2, value.data.dimension().1);
// copy RLWE'(-sm)
izip!(
data.iter_rows_mut().take(value.d_a * 2),
value.data.iter_rows().take(value.d_a * 2)
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// copy RLWE'(m)
izip!(
data.iter_rows_mut().skip(value.d_a * 2),
value.data.iter_rows().skip(value.d_a * 2)
)
.for_each(|(to_ri, from_ri)| {
to_ri.as_mut().copy_from_slice(from_ri.as_ref());
});
// Send polynomials to evaluation domain
let ring_size = data.dimension().1;
let nttop = N::new(&value.modulus, ring_size);
data.iter_rows_mut()
.for_each(|ri| nttop.forward(ri.as_mut()));
Self {
data: data,
modulus: value.modulus.clone(),
_phantom: PhantomData,
}
}
}
impl<M: Debug, Mod: Debug, R, N> Debug for RgswCiphertextEvaluationDomain<M, Mod, R, N> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RgswCiphertextEvaluationDomain")
.field("data", &self.data)
.field("modulus", &self.modulus)
.field("_phantom", &self._phantom)
.finish()
}
}
struct SeededRlweCiphertext<R, S, Mod> {
data: R,
seed: S,
modulus: Mod,
}
impl<R: RowEntity, S, Mod> SeededRlweCiphertext<R, S, Mod> {
fn empty(ring_size: usize, seed: S, modulus: Mod) -> Self {
SeededRlweCiphertext {
data: R::zeros(ring_size),
seed,
modulus,
}
}
}
pub struct RlweCiphertext<M, Rng> {
data: M,
_phatom: PhantomData<Rng>,
}
impl<
R: Row,
M: MatrixEntity<R = R, MatElement = R::Element> + MatrixMut,
Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], Mod>,
Mod: Modulus<Element = R::Element>,
> From<&SeededRlweCiphertext<R, Rng::Seed, Mod>> for RlweCiphertext<M, Rng>
where
Rng::Seed: Clone,
<M as Matrix>::R: RowMut,
R::Element: Copy,
{
fn from(value: &SeededRlweCiphertext<R, Rng::Seed, Mod>) -> Self {
let mut data = M::zeros(2, value.data.as_ref().len());
// sample a
let mut p_rng = Rng::new_with_seed(value.seed.clone());
RandomFillUniformInModulus::random_fill(
&mut p_rng,
&value.modulus,
data.get_row_mut(0),
);
data.get_row_mut(1).copy_from_slice(value.data.as_ref());
RlweCiphertext {
data,
_phatom: PhantomData,
}
}
}
struct SeededRlwePublicKey<Ro: Row, S> {
data: Ro,
seed: S,
modulus: Ro::Element,
}
impl<Ro: RowEntity, S> SeededRlwePublicKey<Ro, S> {
pub(crate) fn empty(ring_size: usize, seed: S, modulus: Ro::Element) -> Self {
Self {
data: Ro::zeros(ring_size),
seed,
modulus,
}
}
}
struct RlwePublicKey<M, R> {
data: M,
_phantom: PhantomData<R>,
}
impl<
M: MatrixMut + MatrixEntity,
Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], M::MatElement>,
> From<&SeededRlwePublicKey<M::R, Rng::Seed>> for RlwePublicKey<M, Rng>
where
<M as Matrix>::R: RowMut,
M::MatElement: Copy,
Rng::Seed: Copy,
{
fn from(value: &SeededRlwePublicKey<M::R, Rng::Seed>) -> Self {
let mut data = M::zeros(2, value.data.as_ref().len());
// sample a
let mut p_rng = Rng::new_with_seed(value.seed);
RandomFillUniformInModulus::random_fill(
&mut p_rng,
&value.modulus,
data.get_row_mut(0),
);
// copy over b
data.get_row_mut(1).copy_from_slice(value.data.as_ref());
Self {
data,
_phantom: PhantomData,
}
}
}
#[derive(Clone)]
struct RlweSecret {
pub(crate) values: Vec<i32>,
}
impl RlweSecret {
pub fn random(hw: usize, n: usize) -> RlweSecret {
DefaultSecureRng::with_local_mut(|rng| {
let mut out = vec![0i32; n];
fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng);
RlweSecret { values: out }
})
}
fn values(&self) -> &[i32] {
&self.values
}
}
fn random_seed() -> [u8; 32] {
let mut rng = DefaultSecureRng::new();
let mut seed = [0u8; 32];
rng.fill_bytes(&mut seed);
seed
}
/// Encrypts m as RGSW ciphertext RGSW(m) using supplied secret key. Returns
/// seeded RGSW ciphertext in coefficient domain
fn sk_encrypt_rgsw<T: Modulus<Element = u64> + Clone>(
m: &[u64],
s: &[i32],
decomposer: &(DefaultDecomposer<u64>, DefaultDecomposer<u64>),
mod_op: &ModularOpsU64<T>,
ntt_op: &NttBackendU64,
) -> SeededRgswCiphertext<Vec<Vec<u64>>, [u8; 32], T> {
let ring_size = s.len();
assert!(m.len() == s.len());
let mut rng = DefaultSecureRng::new();
let q = mod_op.modulus();
let rgsw_seed = random_seed();
let mut seeded_rgsw_ct = SeededRgswCiphertext::<Vec<Vec<u64>>, [u8; 32], T>::empty(
ring_size as usize,
decomposer,
rgsw_seed,
q.clone(),
);
let mut p_rng = DefaultSecureRng::new_seeded(rgsw_seed);
secret_key_encrypt_rgsw(
&mut seeded_rgsw_ct.data,
m,
&decomposer.a().gadget_vector(),
&decomposer.b().gadget_vector(),
s,
mod_op,
ntt_op,
&mut p_rng,
&mut rng,
);
seeded_rgsw_ct
}
#[test]
fn rlwe_encrypt_decryption() {
let logq = 50;
let logp = 2;
let ring_size = 1 << 4;
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
let p = 1u64 << logp;
let mut rng = DefaultSecureRng::new();
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
// sample m0
let mut m0 = vec![0u64; ring_size as usize];
RandomFillUniformInModulus::<[u64], u64>::random_fill(
&mut rng,
&(1u64 << logp),
m0.as_mut_slice(),
);
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
// encrypt m0
let encoded_m = m0
.iter()
.map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64)
.collect_vec();
let seed = random_seed();
let mut rlwe_in_ct =
SeededRlweCiphertext::<Vec<u64>, _, _>::empty(ring_size as usize, seed, q);
let mut p_rng = DefaultSecureRng::new_seeded(seed);
seeded_secret_key_encrypt_rlwe(
&encoded_m,
&mut rlwe_in_ct.data,
s.values(),
&mod_op,
&ntt_op,
&mut p_rng,
&mut rng,
);
let rlwe_in_ct = RlweCiphertext::<Vec<Vec<u64>>, DefaultSecureRng>::from(&rlwe_in_ct);
let mut encoded_m_back = vec![0u64; ring_size as usize];
decrypt_rlwe(
&rlwe_in_ct.data,
s.values(),
&mut encoded_m_back,
&ntt_op,
&mod_op,
);
let m_back = encoded_m_back
.iter()
.map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p)
.collect_vec();
assert_eq!(m0, m_back);
}
#[test]
fn rlwe_by_rgsw_works() {
let logq = 50;
let logp = 2;
let ring_size = 1 << 4;
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
let p: u64 = 1u64 << logp;
let mut rng = DefaultSecureRng::new_seeded([0u8; 32]);
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let mut m0 = vec![0u64; ring_size as usize];
RandomFillUniformInModulus::<[u64], _>::random_fill(
&mut rng,
&(1u64 << logp),
m0.as_mut_slice(),
);
let mut m1 = vec![0u64; ring_size as usize];
m1[thread_rng().gen_range(0..ring_size) as usize] = 1;
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
let d_rgsw = 10;
let logb = 5;
let decomposer = (
DefaultDecomposer::new(q, logb, d_rgsw),
DefaultDecomposer::new(q, logb, d_rgsw),
);
// create public key
let pk_seed = random_seed();
let mut pk_prng = DefaultSecureRng::new_seeded(pk_seed);
let mut seeded_pk =
SeededRlwePublicKey::<Vec<u64>, _>::empty(ring_size as usize, pk_seed, q);
rlwe_public_key(
&mut seeded_pk.data,
s.values(),
&ntt_op,
&mod_op,
&mut pk_prng,
&mut rng,
);
// let pk = RlwePublicKey::<Vec<Vec<u64>>, DefaultSecureRng>::from(&seeded_pk);
// Encrypt m1 as RGSW(m1)
let rgsw_ct = {
// Encryption m1 as RGSW(m1) using secret key
let seeded_rgsw_ct = sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op);
RgswCiphertextEvaluationDomain::<Vec<Vec<u64>>, _,DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct)
};
// Encrypt m0 as RLWE(m0)
let mut rlwe_in_ct = {
let encoded_m = m0
.iter()
.map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64)
.collect_vec();
let seed = random_seed();
let mut p_rng = DefaultSecureRng::new_seeded(seed);
let mut seeded_rlwe = SeededRlweCiphertext::empty(ring_size as usize, seed, q);
seeded_secret_key_encrypt_rlwe(
&encoded_m,
&mut seeded_rlwe.data,
s.values(),
&mod_op,
&ntt_op,
&mut p_rng,
&mut rng,
);
RlweCiphertext::<Vec<Vec<u64>>, DefaultSecureRng>::from(&seeded_rlwe)
};
// RLWE(m0m1) = RLWE(m0) x RGSW(m1)
let mut scratch_space =
vec![vec![0u64; ring_size as usize]; rlwe_x_rgsw_scratch_rows(&decomposer)];
// rlwe x rgsw with with soup repr
let rlwe_in_ct_shoup = {
let mut rlwe_in_ct_shoup = rlwe_in_ct.data.clone();
let rgsw_ct_shoup = ToShoup::to_shoup(&rgsw_ct.data, q);
rlwe_by_rgsw_shoup(
&mut RlweCiphertextMutRef::new(rlwe_in_ct_shoup.as_mut()),
&RgswCiphertextRef::new(
rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&RgswCiphertextRef::new(
rgsw_ct_shoup.as_ref(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer,
&ntt_op,
&mod_op,
false,
);
rlwe_in_ct_shoup
};
// rlwe x rgsw normal
{
rlwe_by_rgsw(
&mut RlweCiphertextMutRef::new(rlwe_in_ct.data.as_mut()),
&RgswCiphertextRef::new(
rgsw_ct.data.as_ref(),
decomposer.a().decomposition_count().0,
decomposer.b().decomposition_count().0,
),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&decomposer,
&ntt_op,
&mod_op,
false,
);
}
// output from both functions must be equal
assert_eq!(rlwe_in_ct.data, rlwe_in_ct_shoup);
// Decrypt RLWE(m0m1)
let mut encoded_m0m1_back = vec![0u64; ring_size as usize];
decrypt_rlwe(
&rlwe_in_ct_shoup,
s.values(),
&mut encoded_m0m1_back,
&ntt_op,
&mod_op,
);
let m0m1_back = encoded_m0m1_back
.iter()
.map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p)
.collect_vec();
let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p;
let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p);
// {
// // measure noise
// let encoded_m_ideal = m0m1
// .iter()
// .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64)
// .collect_vec();
// let noise = measure_noise(&rlwe_in_ct, &encoded_m_ideal, &ntt_op,
// &mod_op, s.values()); println!("Noise RLWE(m0m1)(=
// RLWE(m0)xRGSW(m1)) : {noise}"); }
assert!(
m0m1 == m0m1_back,
"Expected {:?} \n Got {:?}",
m0m1,
m0m1_back
);
}
#[test]
fn rlwe_auto_works() {
let logq = 55;
let ring_size = 1 << 11;
let q = generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap();
let logp = 3;
let p = 1u64 << logp;
let d_rgsw = 5;
let logb = 11;
let mut rng = DefaultSecureRng::new();
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let mut m = vec![0u64; ring_size as usize];
RandomFillUniformInModulus::random_fill(&mut rng, &p, m.as_mut_slice());
let encoded_m = m
.iter()
.map(|v| (((*v as f64 * q as f64) / (p as f64)).round() as u64))
.collect_vec();
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
// RLWE_{s}(m)
let seed_rlwe = random_seed();
let mut seeded_rlwe_m = SeededRlweCiphertext::empty(ring_size as usize, seed_rlwe, q);
let mut p_rng = DefaultSecureRng::new_seeded(seed_rlwe);
seeded_secret_key_encrypt_rlwe(
&encoded_m,
&mut seeded_rlwe_m.data,
s.values(),
&mod_op,
&ntt_op,
&mut p_rng,
&mut rng,
);
let mut rlwe_m = RlweCiphertext::<Vec<Vec<u64>>, DefaultSecureRng>::from(&seeded_rlwe_m);
let auto_k = -125;
// Generate auto key to key switch from s^k to s
let decomposer = DefaultDecomposer::new(q, logb, d_rgsw);
let seed_auto = random_seed();
let mut seeded_auto_key =
SeededAutoKey::empty(ring_size as usize, &decomposer, seed_auto, q);
let mut p_rng = DefaultSecureRng::new_seeded(seed_auto);
let gadget_vector = decomposer.gadget_vector();
seeded_auto_key_gen(
&mut seeded_auto_key.data,
s.values(),
auto_k,
&gadget_vector,
&mod_op,
&ntt_op,
&mut p_rng,
&mut rng,
);
let auto_key =
AutoKeyEvaluationDomain::<Vec<Vec<u64>>, DefaultSecureRng, NttBackendU64>::from(
&seeded_auto_key,
);
// Send RLWE_{s}(m) -> RLWE_{s}(m^k)
let mut scratch_space =
vec![vec![0; ring_size as usize]; rlwe_auto_scratch_rows(&decomposer)];
let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size as usize, auto_k);
// galois auto with auto key in shoup repr
let rlwe_m_shoup = {
let auto_key_shoup = ToShoup::to_shoup(&auto_key.data, q);
let mut rlwe_m_shoup = rlwe_m.data.clone();
rlwe_auto_shoup(
&mut RlweCiphertextMutRef::new(&mut rlwe_m_shoup),
&RlweKskRef::new(&auto_key.data, decomposer.decomposition_count().0),
&RlweKskRef::new(&auto_key_shoup, decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(&mut scratch_space),
&auto_map_index,
&auto_map_sign,
&mod_op,
&ntt_op,
&decomposer,
false,
);
rlwe_m_shoup
};
// normal galois auto
{
rlwe_auto(
&mut RlweCiphertextMutRef::new(rlwe_m.data.as_mut()),
&RlweKskRef::new(auto_key.data.as_ref(), decomposer.decomposition_count().0),
&mut RuntimeScratchMutRef::new(scratch_space.as_mut()),
&auto_map_index,
&auto_map_sign,
&mod_op,
&ntt_op,
&decomposer,
false,
);
}
// rlwe out from both functions must be same
assert_eq!(rlwe_m.data, rlwe_m_shoup);
let rlwe_m_k = rlwe_m;
// Decrypt RLWE_{s}(m^k) and check
let mut encoded_m_k_back = vec![0u64; ring_size as usize];
decrypt_rlwe(
&rlwe_m_k.data,
s.values(),
&mut encoded_m_k_back,
&ntt_op,
&mod_op,
);
let m_k_back = encoded_m_k_back
.iter()
.map(|v| (((*v as f64 * p as f64) / q as f64).round() as u64) % p)
.collect_vec();
let mut m_k = vec![0u64; ring_size as usize];
// Send \delta m -> \delta m^k
izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each(
|(v, to_index, sign)| {
if !*sign {
m_k[*to_index] = (p - *v) % p;
} else {
m_k[*to_index] = *v;
}
},
);
// {
// let encoded_m_k = m_k
// .iter()
// .map(|v| ((*v as f64 * q as f64) / p as f64).round() as u64)
// .collect_vec();
// let noise = measure_noise(&rlwe_m_k, &encoded_m_k, &ntt_op, &mod_op,
// s.values()); println!("Ksk noise: {noise}");
// }
assert_eq!(m_k_back, m_k);
}
/// Collect noise stats of RGSW ciphertext
///
/// - rgsw_ct: RGSW ciphertext must be in coefficient domain
fn rgsw_noise_stats<T: Modulus<Element = u64> + Clone>(
rgsw_ct: &[Vec<u64>],
m: &[u64],
s: &[i32],
decomposer: &(DefaultDecomposer<u64>, DefaultDecomposer<u64>),
q: &T,
) -> Stats<i64> {
let gadget_vector_a = decomposer.a().gadget_vector();
let gadget_vector_b = decomposer.b().gadget_vector();
let d_a = gadget_vector_a.len();
let d_b = gadget_vector_b.len();
let ring_size = s.len();
assert!(Matrix::dimension(&rgsw_ct) == (d_a * 2 + d_b * 2, ring_size));
assert!(m.len() == ring_size);
let mod_op = ModularOpsU64::new(q.clone());
let ntt_op = NttBackendU64::new(q, ring_size);
let mul_mod =
|a: &u64, b: &u64| ((*a as u128 * *b as u128) % q.q().unwrap() as u128) as u64;
let s_poly = Vec::<u64>::try_convert_from(s, q);
let mut neg_s = s_poly.clone();
mod_op.elwise_neg_mut(neg_s.as_mut());
let neg_sm0m1 = negacyclic_mul(&neg_s, &m, mul_mod, q.q().unwrap());
let mut stats = Stats::new();
// RLWE(\beta^j -s * m)
for j in 0..d_a {
let want_m = {
// RLWE(\beta^j -s * m)
let mut beta_neg_sm0m1 = vec![0u64; ring_size as usize];
mod_op.elwise_scalar_mul(beta_neg_sm0m1.as_mut(), &neg_sm0m1, &gadget_vector_a[j]);
beta_neg_sm0m1
};
let mut rlwe = vec![vec![0u64; ring_size as usize]; 2];
rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(j));
rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a + j));
let mut got_m = vec![0; ring_size];
decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op);
let mut diff = want_m;
mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref());
stats.add_many_samples(&Vec::<i64>::try_convert_from(&diff, q));
}
// RLWE(\beta^j m)
for j in 0..d_b {
let want_m = {
// RLWE(\beta^j m)
let mut beta_m0m1 = vec![0u64; ring_size as usize];
mod_op.elwise_scalar_mul(beta_m0m1.as_mut(), &m, &gadget_vector_b[j]);
beta_m0m1
};
let mut rlwe = vec![vec![0u64; ring_size as usize]; 2];
rlwe[0].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + j));
rlwe[1].copy_from_slice(rgsw_ct.get_row_slice(d_a * 2 + d_b + j));
let mut got_m = vec![0; ring_size];
decrypt_rlwe(&rlwe, s, &mut got_m, &ntt_op, &mod_op);
let mut diff = want_m;
mod_op.elwise_sub_mut(diff.as_mut(), got_m.as_ref());
stats.add_many_samples(&Vec::<i64>::try_convert_from(&diff, q));
}
stats
}
#[test]
fn print_noise_stats_rgsw_x_rgsw() {
let logq = 60;
let logp = 2;
let ring_size = 1 << 11;
let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap();
let d_rgsw = 12;
let logb = 5;
let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize);
let ntt_op = NttBackendU64::new(&q, ring_size as usize);
let mod_op = ModularOpsU64::new(q);
let decomposer = (
DefaultDecomposer::new(q, logb, d_rgsw),
DefaultDecomposer::new(q, logb, d_rgsw),
);
let d_a = decomposer.a().decomposition_count().0;
let d_b = decomposer.b().decomposition_count().0;
let mul_mod = |a: &u64, b: &u64| ((*a as u128 * *b as u128) % q as u128) as u64;
let mut carry_m = vec![0u64; ring_size as usize];
carry_m[thread_rng().gen_range(0..ring_size) as usize] = 1 << logp;
// RGSW(carry_m)
let mut rgsw_carrym = {
let seeded_rgsw = sk_encrypt_rgsw(&carry_m, s.values(), &decomposer, &mod_op, &ntt_op);
let mut rgsw_eval =
RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(
&seeded_rgsw,
);
rgsw_eval
.data
.iter_mut()
.for_each(|ri| ntt_op.backward(ri.as_mut()));
rgsw_eval.data
};
let mut scratch_matrix = vec![
vec![0u64; ring_size as usize];
rgsw_x_rgsw_scratch_rows(&decomposer, &decomposer)
];
rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q);
for i in 0..8 {
let mut m = vec![0u64; ring_size as usize];
m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 };
let rgsw_m =
RgswCiphertextEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(
&sk_encrypt_rgsw(&m, s.values(), &decomposer, &mod_op, &ntt_op),
);
rgsw_by_rgsw_inplace(
&mut RgswCiphertextMutRef::new(rgsw_carrym.as_mut(), d_a, d_b),
&RgswCiphertextRef::new(rgsw_m.data.as_ref(), d_a, d_b),
&decomposer,
&decomposer,
&mut RuntimeScratchMutRef::new(scratch_matrix.as_mut()),
&ntt_op,
&mod_op,
);
// measure noise
carry_m = negacyclic_mul(&carry_m, &m, mul_mod, q);
let stats = rgsw_noise_stats(&rgsw_carrym, &carry_m, s.values(), &decomposer, &q);
println!(
"Log2 of noise std after {i} RGSW x RGSW: {}",
stats.std_dev().abs().log2()
);
}
}
}

+ 1063
- 0
src/rgsw/runtime.rs
File diff suppressed because it is too large
View File


+ 370
- 0
src/shortint/enc_dec.rs

@ -0,0 +1,370 @@
use itertools::Itertools;
use crate::{
bool::BoolEvaluator,
random::{DefaultSecureRng, RandomFillUniformInModulus},
utils::WithLocal,
Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor,
RowMut, SampleExtractor,
};
/// Fhe UInt8
///
/// Note that `Self.data` stores encryptions of bits in little endian (i.e least
/// signficant bit stored at 0th index and most signficant bit stores at 7th
/// index)
#[derive(Clone)]
pub struct FheUint8<C> {
pub(super) data: Vec<C>,
}
impl<C> FheUint8<C> {
pub(super) fn data(&self) -> &[C] {
&self.data
}
pub(super) fn data_mut(&mut self) -> &mut [C] {
&mut self.data
}
}
/// Stores a batch of Fhe Uint8 ciphertext as collection of unseeded RLWE
/// ciphertexts always encrypted under the ideal RLWE secret `s` of the MPC
/// protocol
///
/// To extract Fhe Uint8 ciphertext at `index` call `self.extract(index)`
pub struct BatchedFheUint8<C> {
/// Vector of RLWE ciphertexts `C`
data: Vec<C>,
/// Count of FheUint8s packed in vector of RLWE ciphertexts
count: usize,
}
impl<K, C> Encryptor<[u8], BatchedFheUint8<C>> for K
where
K: Encryptor<[bool], Vec<C>>,
{
/// Encrypt a batch of uint8s packed in vector of RLWE ciphertexts
///
/// Uint8s can be extracted from `BatchedFheUint8` with `SampleExtractor`
fn encrypt(&self, m: &[u8]) -> BatchedFheUint8<C> {
let bool_m = m
.iter()
.flat_map(|v| {
(0..8)
.into_iter()
.map(|i| ((*v >> i) & 1) == 1)
.collect_vec()
})
.collect_vec();
let cts = K::encrypt(&self, &bool_m);
BatchedFheUint8 {
data: cts,
count: m.len(),
}
}
}
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&SeededBatchedFheUint8<M::R, [u8; 32]>>
for BatchedFheUint8<M>
where
<M as Matrix>::R: RowMut,
{
/// Unseeds collection of seeded RLWE ciphertext in SeededBatchedFheUint8
/// and returns as `Self`
fn from(value: &SeededBatchedFheUint8<M::R, [u8; 32]>) -> Self {
BoolEvaluator::with_local(|e| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_q = parameters.rlwe_q();
let mut prng = DefaultSecureRng::new_seeded(value.seed);
let rlwes = value
.data
.iter()
.map(|partb| {
let mut rlwe = M::zeros(2, ring_size);
// sample A
RandomFillUniformInModulus::random_fill(&mut prng, rlwe_q, rlwe.get_row_mut(0));
// Copy over B
rlwe.get_row_mut(1).copy_from_slice(partb.as_ref());
rlwe
})
.collect_vec();
Self {
data: rlwes,
count: value.count,
}
})
}
}
impl<C, R> SampleExtractor<FheUint8<R>> for BatchedFheUint8<C>
where
C: SampleExtractor<R>,
{
/// Extract Fhe Uint8 ciphertext at `index`
///
/// `Self` stores batch of Fhe uint8 ciphertext as vector of RLWE
/// ciphertexts. Since Fhe uint8 ciphertext is collection of 8 bool
/// ciphertexts, Fhe uint8 ciphertext at index `i` is stored in coefficients
/// `i*8...(i+1)*8`. To extract Fhe uint8 at index `i`, sample extract bool
/// ciphertext at indices `[i*8, ..., (i+1)*8)`
fn extract_at(&self, index: usize) -> FheUint8<R> {
assert!(index < self.count);
BoolEvaluator::with_local(|e| {
let ring_size = e.parameters().rlwe_n().0;
let start_index = index * 8;
let end_index = (index + 1) * 8;
let data = (start_index..end_index)
.map(|i| {
let rlwe_index = i / ring_size;
let coeff_index = i % ring_size;
self.data[rlwe_index].extract_at(coeff_index)
})
.collect_vec();
FheUint8 { data }
})
}
/// Extracts all FheUint8s packed in vector of RLWE ciphertexts of `Self`
fn extract_all(&self) -> Vec<FheUint8<R>> {
(0..self.count)
.map(|index| self.extract_at(index))
.collect_vec()
}
/// Extracts first `how_many` FheUint8s packed in vector of RLWE
/// ciphertexts of `Self`
fn extract_many(&self, how_many: usize) -> Vec<FheUint8<R>> {
(0..how_many)
.map(|index| self.extract_at(index))
.collect_vec()
}
}
/// Stores a batch of FheUint8s packed in a collection unseeded RLWE ciphertexts
///
/// `Self` stores unseeded RLWE ciphertexts encrypted under user's RLWE secret
/// `u_j` and is different from `BatchFheUint8` which stores collection of RLWE
/// ciphertexts under ideal RLWE secret `s` of the (non-interactive/interactive)
/// MPC protocol.
///
/// To extract FheUint8s from `Self`'s collection of RLWE ciphertexts, first
/// switch `Self` to `BatchFheUint8` with `key_switch(user_id)` where `user_id`
/// is user's id. This key switches collection of RLWE ciphertexts from
/// user's RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol. Then
/// proceed to use `SampleExtract` on `BatchFheUint8` (for ex, call
/// `extract_at(0)` to extract FheUint8 stored at index 0)
pub struct NonInteractiveBatchedFheUint8<C> {
/// Vector of RLWE ciphertexts `C`
data: Vec<C>,
/// Count of FheUint8s packed in vector of RLWE ciphertexts
count: usize,
}
impl<M: MatrixEntity + MatrixMut<MatElement = u64>> From<&SeededBatchedFheUint8<M::R, [u8; 32]>>
for NonInteractiveBatchedFheUint8<M>
where
<M as Matrix>::R: RowMut,
{
/// Unseeds collection of seeded RLWE ciphertext in SeededBatchedFheUint8
/// and returns as `Self`
fn from(value: &SeededBatchedFheUint8<M::R, [u8; 32]>) -> Self {
BoolEvaluator::with_local(|e| {
let parameters = e.parameters();
let ring_size = parameters.rlwe_n().0;
let rlwe_q = parameters.rlwe_q();
let mut prng = DefaultSecureRng::new_seeded(value.seed);
let rlwes = value
.data
.iter()
.map(|partb| {
let mut rlwe = M::zeros(2, ring_size);
// sample A
RandomFillUniformInModulus::random_fill(&mut prng, rlwe_q, rlwe.get_row_mut(0));
// Copy over B
rlwe.get_row_mut(1).copy_from_slice(partb.as_ref());
rlwe
})
.collect_vec();
Self {
data: rlwes,
count: value.count,
}
})
}
}
impl<C> KeySwitchWithId<BatchedFheUint8<C>> for NonInteractiveBatchedFheUint8<C>
where
C: KeySwitchWithId<C>,
{
/// Key switch `Self`'s collection of RLWE cihertexts encrypted under user's
/// RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol.
///
/// - user_id: user id of user `j`
fn key_switch(&self, user_id: usize) -> BatchedFheUint8<C> {
let data = self
.data
.iter()
.map(|c| c.key_switch(user_id))
.collect_vec();
BatchedFheUint8 {
data,
count: self.count,
}
}
}
pub struct SeededBatchedFheUint8<C, S> {
/// Vector of Seeded RLWE ciphertexts `C`.
///
/// If RLWE(m) = [a, b] s.t. m + e = b - as, `a` can be seeded and seeded
/// RLWE ciphertext only contains `b` polynomial
data: Vec<C>,
/// Seed for the ciphertexts
seed: S,
/// Count of FheUint8s packed in vector of RLWE ciphertexts
count: usize,
}
impl<K, C, S> Encryptor<[u8], SeededBatchedFheUint8<C, S>> for K
where
K: Encryptor<[bool], (Vec<C>, S)>,
{
/// Encrypt a slice of u8s of arbitray length packed into collection of
/// seeded RLWE ciphertexts and return `SeededBatchedFheUint8`
fn encrypt(&self, m: &[u8]) -> SeededBatchedFheUint8<C, S> {
// convert vector of u8s to vector bools
let bool_m = m
.iter()
.flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1))
.collect_vec();
let (cts, seed) = K::encrypt(&self, &bool_m);
SeededBatchedFheUint8 {
data: cts,
seed,
count: m.len(),
}
}
}
impl<C, S> SeededBatchedFheUint8<C, S> {
/// Unseed collection of seeded RLWE ciphertexts of `Self` and returns
/// `NonInteractiveBatchedFheUint8` with collection of unseeded RLWE
/// ciphertexts.
///
/// In non-interactive MPC setting, RLWE ciphertexts are encrypted under
/// user's RLWE secret `u_j`. The RLWE ciphertexts must be key switched to
/// ideal RLWE secret `s` of the MPC protocol before use.
///
/// Note that we don't provide `unseed` API from `Self` to
/// `BatchedFheUint8`. This is because:
///
/// - In non-interactive setting (1) client encrypts private inputs using
/// their secret `u_j` as `SeededBatchedFheUint8` and sends it to the
/// server. (2) Server unseeds `SeededBatchedFheUint8` into
/// `NonInteractiveBatchedFheUint8` indicating that private inputs are
/// still encrypted under user's RLWE secret `u_j`. (3) Server key
/// switches `NonInteractiveBatchedFheUint8` from user's RLWE secret `u_j`
/// to ideal RLWE secret `s` and outputs `BatchedFheUint8`. (4)
/// `BatchedFheUint8` always stores RLWE secret under ideal RLWE secret of
/// the protocol. Hence, it is safe to extract FheUint8s. Server proceeds
/// to extract necessary FheUint8s.
///
/// - In interactive setting (1) client always encrypts private inputs using
/// public key corresponding to ideal RLWE secret `s` of the protocol and
/// produces `BatchedFheUint8`. (2) Given `BatchedFheUint8` stores
/// collection of RLWE ciphertext under ideal RLWE secret `s`, server can
/// directly extract necessary FheUint8s to use.
///
/// Thus, there's no need to go directly from `Self` to `BatchedFheUint8`.
pub fn unseed<M>(&self) -> NonInteractiveBatchedFheUint8<M>
where
NonInteractiveBatchedFheUint8<M>: for<'a> From<&'a SeededBatchedFheUint8<C, S>>,
M: Matrix<R = C>,
{
NonInteractiveBatchedFheUint8::from(self)
}
}
impl<C, K> MultiPartyDecryptor<u8, FheUint8<C>> for K
where
K: MultiPartyDecryptor<bool, C>,
<Self as MultiPartyDecryptor<bool, C>>::DecryptionShare: Clone,
{
type DecryptionShare = Vec<<Self as MultiPartyDecryptor<bool, C>>::DecryptionShare>;
fn gen_decryption_share(&self, c: &FheUint8<C>) -> Self::DecryptionShare {
assert!(c.data().len() == 8);
c.data()
.iter()
.map(|bit_c| {
let decryption_share =
MultiPartyDecryptor::<bool, C>::gen_decryption_share(self, bit_c);
decryption_share
})
.collect_vec()
}
fn aggregate_decryption_shares(&self, c: &FheUint8<C>, shares: &[Self::DecryptionShare]) -> u8 {
let mut out = 0u8;
(0..8).into_iter().for_each(|i| {
// Collect bit i^th decryption share of each party
let bit_i_decryption_shares = shares.iter().map(|s| s[i].clone()).collect_vec();
let bit_i = MultiPartyDecryptor::<bool, C>::aggregate_decryption_shares(
self,
&c.data()[i],
&bit_i_decryption_shares,
);
if bit_i {
out += 1 << i;
}
});
out
}
}
impl<C, K> Encryptor<u8, FheUint8<C>> for K
where
K: Encryptor<bool, C>,
{
fn encrypt(&self, m: &u8) -> FheUint8<C> {
let cts = (0..8)
.into_iter()
.map(|i| {
let bit = ((m >> i) & 1) == 1;
K::encrypt(self, &bit)
})
.collect_vec();
FheUint8 { data: cts }
}
}
impl<K, C> Decryptor<u8, FheUint8<C>> for K
where
K: Decryptor<bool, C>,
{
fn decrypt(&self, c: &FheUint8<C>) -> u8 {
assert!(c.data.len() == 8);
let mut out = 0u8;
c.data().iter().enumerate().for_each(|(index, bit_c)| {
let bool = K::decrypt(self, bit_c);
if bool {
out += 1 << index;
}
});
out
}
}

+ 294
- 0
src/shortint/mod.rs

@ -0,0 +1,294 @@
mod enc_dec;
mod ops;
pub type FheUint8 = enc_dec::FheUint8<Vec<u64>>;
use std::cell::RefCell;
use crate::bool::{BoolEvaluator, BooleanGates, FheBool, RuntimeServerKey};
thread_local! {
static DIV_ZERO_ERROR: RefCell<Option<FheBool>> = RefCell::new(None);
}
/// Returns Boolean ciphertext indicating whether last division was attempeted
/// with decnomiantor set to 0.
pub fn div_zero_error_flag() -> Option<FheBool> {
DIV_ZERO_ERROR.with_borrow(|c| c.clone())
}
/// Reset all error flags
///
/// Error flags are thread local. When running multiple circuits in sequence
/// within a single program you must prevent error flags set during the
/// execution of previous circuit to affect error flags set during execution of
/// the next circuit. To do so call `reset_error_flags()`.
pub fn reset_error_flags() {
DIV_ZERO_ERROR.with_borrow_mut(|c| *c = None);
}
mod frontend {
use super::ops::{
arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor,
eight_bit_mul, is_zero,
};
use crate::utils::{Global, WithLocal};
use super::*;
/// Set Div by Zero flag after each divison. Div by zero flag is set to true
/// if either 1 of the division executed in circuit evaluation has
/// denominator set to 0.
fn set_div_by_zero_flag(denominator: &FheUint8) {
{
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let is_zero = is_zero(e, denominator.data(), key);
DIV_ZERO_ERROR.with_borrow_mut(|before_is_zero| {
if before_is_zero.is_none() {
*before_is_zero = Some(FheBool { data: is_zero });
} else {
e.or_inplace(before_is_zero.as_mut().unwrap().data_mut(), &is_zero, key);
}
});
})
}
}
mod arithetic {
use super::*;
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub};
impl AddAssign<&FheUint8> for FheUint8 {
fn add_assign(&mut self, rhs: &FheUint8) {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key);
});
}
}
impl Add<&FheUint8> for &FheUint8 {
type Output = FheUint8;
fn add(self, rhs: &FheUint8) -> Self::Output {
let mut a = self.clone();
a += rhs;
a
}
}
impl Sub<&FheUint8> for &FheUint8 {
type Output = FheUint8;
fn sub(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
FheUint8 { data: out }
})
}
}
impl Mul<&FheUint8> for &FheUint8 {
type Output = FheUint8;
fn mul(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = eight_bit_mul(e, self.data(), rhs.data(), key);
FheUint8 { data: out }
})
}
}
impl Div<&FheUint8> for &FheUint8 {
type Output = FheUint8;
fn div(self, rhs: &FheUint8) -> Self::Output {
// set div by 0 error flag
set_div_by_zero_flag(rhs);
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem(
e,
self.data(),
rhs.data(),
key,
);
FheUint8 { data: quotient }
})
}
}
impl Rem<&FheUint8> for &FheUint8 {
type Output = FheUint8;
fn rem(self, rhs: &FheUint8) -> Self::Output {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem(
e,
self.data(),
rhs.data(),
key,
);
FheUint8 { data: remainder }
})
}
}
impl FheUint8 {
/// Calculates `Self += rhs` and returns `overflow`
///
/// `overflow` is set to `True` if `Self += rhs` overflowed,
/// otherwise it is set to `False`
pub fn overflowing_add_assign(&mut self, rhs: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut_mut(&mut |e| {
let key = RuntimeServerKey::global();
let (overflow, _) =
arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key);
FheBool { data: overflow }
})
}
/// Returns (Self + rhs, overflow).
///
/// `overflow` is set to `True` if `Self + rhs` overflowed,
/// otherwise it is set to `False`
pub fn overflowing_add(self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| {
let mut lhs = self.clone();
let key = RuntimeServerKey::global();
let (overflow, _) =
arbitrary_bit_adder(e, lhs.data_mut(), rhs.data(), false, key);
(lhs, FheBool { data: overflow })
})
}
/// Returns (Self - rhs, overflow).
///
/// `overflow` is set to `True` if `Self - rhs` overflowed,
/// otherwise it is set to `False`
pub fn overflowing_sub(&self, rhs: &FheUint8) -> (FheUint8, FheBool) {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let (out, mut overflow, _) =
arbitrary_bit_subtractor(e, self.data(), rhs.data(), key);
e.not_inplace(&mut overflow);
(FheUint8 { data: out }, FheBool { data: overflow })
})
}
/// Returns (quotient, remainder) s.t. self = rhs x quotient +
/// remainder.
///
/// If rhs is 0, then quotient = 255, remainder = self, and Div by
/// Zero error flag (accessible via `div_zero_error_flag`) is set to
/// `True`
pub fn div_rem(&self, rhs: &FheUint8) -> (FheUint8, FheUint8) {
// set div by 0 error flag
set_div_by_zero_flag(rhs);
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let (quotient, remainder) = arbitrary_bit_division_for_quotient_and_rem(
e,
self.data(),
rhs.data(),
key,
);
(FheUint8 { data: quotient }, FheUint8 { data: remainder })
})
}
}
}
mod booleans {
use crate::shortint::ops::{
arbitrary_bit_comparator, arbitrary_bit_equality, arbitrary_bit_mux,
};
use super::*;
impl FheUint8 {
/// Returns `FheBool` indicating `Self == other`
pub fn eq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = arbitrary_bit_equality(e, self.data(), other.data(), key);
FheBool { data: out }
})
}
/// Returns `FheBool` indicating `Self != other`
pub fn neq(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let mut is_equal = arbitrary_bit_equality(e, self.data(), other.data(), key);
e.not_inplace(&mut is_equal);
FheBool { data: is_equal }
})
}
/// Returns `FheBool` indicating `Self < other`
pub fn lt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = arbitrary_bit_comparator(e, other.data(), self.data(), key);
FheBool { data: out }
})
}
/// Returns `FheBool` indicating `Self > other`
pub fn gt(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = arbitrary_bit_comparator(e, self.data(), other.data(), key);
FheBool { data: out }
})
}
/// Returns `FheBool` indicating `Self <= other`
pub fn le(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let mut a_greater_b =
arbitrary_bit_comparator(e, self.data(), other.data(), key);
e.not_inplace(&mut a_greater_b);
FheBool { data: a_greater_b }
})
}
/// Returns `FheBool` indicating `Self >= other`
pub fn ge(&self, other: &FheUint8) -> FheBool {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let mut a_less_b = arbitrary_bit_comparator(e, other.data(), self.data(), key);
e.not_inplace(&mut a_less_b);
FheBool { data: a_less_b }
})
}
/// Returns `Self` if `selector = True` else returns `other`
pub fn mux(&self, other: &FheUint8, selector: &FheBool) -> FheUint8 {
BoolEvaluator::with_local_mut(|e| {
let key = RuntimeServerKey::global();
let out = arbitrary_bit_mux(e, selector.data(), self.data(), other.data(), key);
FheUint8 { data: out }
})
}
/// Returns max(`Self`, `other`)
pub fn max(&self, other: &FheUint8) -> FheUint8 {
let self_gt = self.gt(other);
self.mux(other, &self_gt)
}
/// Returns min(`Self`, `other`)
pub fn min(&self, other: &FheUint8) -> FheUint8 {
let self_lt = self.lt(other);
self.mux(other, &self_lt)
}
}
}
}

+ 356
- 0
src/shortint/ops.rs

@ -0,0 +1,356 @@
use itertools::{izip, Itertools};
use crate::bool::BooleanGates;
pub(super) fn half_adder<E: BooleanGates>(
evaluator: &mut E,
a: &mut E::Ciphertext,
b: &E::Ciphertext,
key: &E::Key,
) -> E::Ciphertext {
let carry = evaluator.and(a, b, key);
evaluator.xor_inplace(a, b, key);
carry
}
pub(super) fn full_adder_plain_carry_in<E: BooleanGates>(
evaluator: &mut E,
a: &mut E::Ciphertext,
b: &E::Ciphertext,
carry_in: bool,
key: &E::Key,
) -> E::Ciphertext {
let mut a_and_b = evaluator.and(a, b, key);
evaluator.xor_inplace(a, b, key); //a = a ^ b
if carry_in {
// a_and_b = A & B | ((A^B) & C_in={True})
evaluator.or_inplace(&mut a_and_b, &a, key);
} else {
// a_and_b = A & B | ((A^B) & C_in={False})
// a_and_b = A & B
// noop
}
// In xor if a input is 0, output equals the firt variable. If input is 1 then
// output equals !(first variable)
if carry_in {
// (A^B)^1 = !(A^B)
evaluator.not_inplace(a);
} else {
// (A^B)^0
// no-op
}
a_and_b
}
pub(super) fn full_adder<E: BooleanGates>(
evaluator: &mut E,
a: &mut E::Ciphertext,
b: &E::Ciphertext,
carry_in: &E::Ciphertext,
key: &E::Key,
) -> E::Ciphertext {
let mut a_and_b = evaluator.and(a, b, key);
evaluator.xor_inplace(a, b, key); //a = a ^ b
let a_xor_b_and_c = evaluator.and(&a, carry_in, key);
evaluator.or_inplace(&mut a_and_b, &a_xor_b_and_c, key); // a_and_b = A & B | ((A^B) & C_in)
evaluator.xor_inplace(a, &carry_in, key);
a_and_b
}
pub(super) fn arbitrary_bit_adder<E: BooleanGates>(
evaluator: &mut E,
a: &mut [E::Ciphertext],
b: &[E::Ciphertext],
carry_in: bool,
key: &E::Key,
) -> (E::Ciphertext, E::Ciphertext)
where
E::Ciphertext: Clone,
{
assert!(a.len() == b.len());
let n = a.len();
let mut carry = if !carry_in {
half_adder(evaluator, &mut a[0], &b[0], key)
} else {
full_adder_plain_carry_in(evaluator, &mut a[0], &b[0], true, key)
};
izip!(a.iter_mut(), b.iter())
.skip(1)
.take(n - 3)
.for_each(|(a_bit, b_bit)| {
carry = full_adder(evaluator, a_bit, b_bit, &carry, key);
});
let carry_last_last = full_adder(evaluator, &mut a[n - 2], &b[n - 2], &carry, key);
let carry_last = full_adder(evaluator, &mut a[n - 1], &b[n - 1], &carry_last_last, key);
(carry_last, carry_last_last)
}
pub(super) fn arbitrary_bit_subtractor<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> (Vec<E::Ciphertext>, E::Ciphertext, E::Ciphertext)
where
E::Ciphertext: Clone,
{
let mut neg_b: Vec<E::Ciphertext> = b.iter().map(|v| evaluator.not(v)).collect();
let (carry_last, carry_last_last) = arbitrary_bit_adder(evaluator, &mut neg_b, &a, true, key);
return (neg_b, carry_last, carry_last_last);
}
pub(super) fn bit_mux<E: BooleanGates>(
evaluator: &mut E,
selector: E::Ciphertext,
if_true: &E::Ciphertext,
if_false: &E::Ciphertext,
key: &E::Key,
) -> E::Ciphertext {
// (s&a) | ((1-s)^b)
let not_selector = evaluator.not(&selector);
let mut s_and_a = evaluator.and(&selector, if_true, key);
let s_and_b = evaluator.and(&not_selector, if_false, key);
evaluator.or(&mut s_and_a, &s_and_b, key);
s_and_a
}
pub(super) fn arbitrary_bit_mux<E: BooleanGates>(
evaluator: &mut E,
selector: &E::Ciphertext,
if_true: &[E::Ciphertext],
if_false: &[E::Ciphertext],
key: &E::Key,
) -> Vec<E::Ciphertext> {
// (s&a) | ((1-s)^b)
let not_selector = evaluator.not(&selector);
izip!(if_true.iter(), if_false.iter())
.map(|(a, b)| {
let mut s_and_a = evaluator.and(&selector, a, key);
let s_and_b = evaluator.and(&not_selector, b, key);
evaluator.or_inplace(&mut s_and_a, &s_and_b, key);
s_and_a
})
.collect()
}
pub(super) fn eight_bit_mul<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> Vec<E::Ciphertext> {
assert!(a.len() == 8);
assert!(b.len() == 8);
let mut carries = Vec::with_capacity(7);
let mut out = Vec::with_capacity(8);
for i in (0..8) {
if i == 0 {
let s = evaluator.and(&a[0], &b[0], key);
out.push(s);
} else if i == 1 {
let mut tmp0 = evaluator.and(&a[1], &b[0], key);
let tmp1 = evaluator.and(&a[0], &b[1], key);
let carry = half_adder(evaluator, &mut tmp0, &tmp1, key);
carries.push(carry);
out.push(tmp0);
} else {
let mut sum = {
let mut sum = evaluator.and(&a[i], &b[0], key);
let tmp = evaluator.and(&a[i - 1], &b[1], key);
carries[0] = full_adder(evaluator, &mut sum, &tmp, &carries[0], key);
sum
};
for j in 2..i {
let tmp = evaluator.and(&a[i - j], &b[j], key);
carries[j - 1] = full_adder(evaluator, &mut sum, &tmp, &carries[j - 1], key);
}
let tmp = evaluator.and(&a[0], &b[i], key);
let carry = half_adder(evaluator, &mut sum, &tmp, key);
carries.push(carry);
out.push(sum)
}
debug_assert!(carries.len() <= 7);
}
out
}
pub(super) fn arbitrary_bit_division_for_quotient_and_rem<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> (Vec<E::Ciphertext>, Vec<E::Ciphertext>)
where
E::Ciphertext: Clone,
{
let n = a.len();
let neg_b = b.iter().map(|v| evaluator.not(v)).collect_vec();
// Both remainder and quotient are initially stored in Big-endian in contrast to
// the usual little endian we use. This is more friendly to vec pushes in
// division. After computing remainder and quotient, we simply reverse the
// vectors.
let mut remainder = vec![];
let mut quotient = vec![];
for i in 0..n {
// left shift
remainder.push(a[n - 1 - i].clone());
let mut subtract = remainder.clone();
// subtraction
// At i^th iteration remainder is only filled with i bits and the rest of the
// bits are zero. For example, at i = 1
// 0 0 0 0 0 0 X X => remainder
// - Y Y Y Y Y Y Y Y => divisor .
// --------------- .
// Z Z Z Z Z Z Z Z => result
// For the next iteration we only care about result if divisor is <= remainder
// (which implies result <= remainder). Otherwise we care about remainder
// (recall re-storing division). Hence we optimise subtraction and
// ignore full adders for places where remainder bits are known to be false
// bits. We instead use `ANDs` to compute the carry overs, since the
// last carry over indicates whether the value has overflown (i.e. divisor <=
// remainder). Last carry out is `true` if value has not overflown, otherwise
// false.
let mut carry =
full_adder_plain_carry_in(evaluator, &mut subtract[i], &neg_b[0], true, key);
for j in 1..i + 1 {
carry = full_adder(evaluator, &mut subtract[i - j], &neg_b[j], &carry, key);
}
for j in i + 1..n {
// All I care about are the carries
evaluator.and_inplace(&mut carry, &neg_b[j], key);
}
let not_carry = evaluator.not(&carry);
// Choose `remainder` if subtraction has overflown (i.e. carry = false).
// Otherwise choose `subtractor`.
//
// mux k^a | !(k)^b, where k is the selector.
izip!(remainder.iter_mut(), subtract.iter_mut()).for_each(|(r, s)| {
// choose `s` when carry is true, otherwise choose r
evaluator.and_inplace(s, &carry, key);
evaluator.and_inplace(r, &not_carry, key);
evaluator.or_inplace(r, s, key);
});
// Set i^th MSB of quotient to 1 if carry = true, otherwise set it to 0.
// X&1 | X&0 => X&1 => X
quotient.push(carry);
}
remainder.reverse();
quotient.reverse();
(quotient, remainder)
}
pub(super) fn is_zero<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
key: &E::Key,
) -> E::Ciphertext {
let mut a = a.iter().map(|v| evaluator.not(v)).collect_vec();
let (out, rest_a) = a.split_at_mut(1);
rest_a.iter().for_each(|c| {
evaluator.and_inplace(&mut out[0], c, key);
});
return a.remove(0);
}
pub(super) fn arbitrary_bit_equality<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> E::Ciphertext {
assert!(a.len() == b.len());
let mut out = evaluator.xnor(&a[0], &b[0], key);
izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| {
let e = evaluator.xnor(abit, bbit, key);
evaluator.and_inplace(&mut out, &e, key);
});
return out;
}
/// Comparator handle computes comparator result 2ns MSB onwards. It is
/// separated because comparator subroutine for signed and unsgind integers
/// differs only for 1st MSB and is common second MSB onwards
fn _comparator_handler_from_second_msb<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
mut comp: E::Ciphertext,
mut casc: E::Ciphertext,
key: &E::Key,
) -> E::Ciphertext {
let n = a.len();
// handle MSB - 1
let mut tmp = evaluator.not(&b[n - 2]);
evaluator.and_inplace(&mut tmp, &a[n - 2], key);
evaluator.and_inplace(&mut tmp, &casc, key);
evaluator.or_inplace(&mut comp, &tmp, key);
for i in 2..n {
// calculate cascading bit
let tmp_casc = evaluator.xnor(&a[n - i], &b[n - i], key);
evaluator.and_inplace(&mut casc, &tmp_casc, key);
// calculate computate bit
let mut tmp = evaluator.not(&b[n - 1 - i]);
evaluator.and_inplace(&mut tmp, &a[n - 1 - i], key);
evaluator.and_inplace(&mut tmp, &casc, key);
evaluator.or_inplace(&mut comp, &tmp, key);
}
return comp;
}
/// Signed integer comparison is same as unsigned integer with MSB flipped.
pub(super) fn arbitrary_signed_bit_comparator<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> E::Ciphertext {
assert!(a.len() == b.len());
let n = a.len();
// handle MSB
let mut comp = evaluator.not(&a[n - 1]);
evaluator.and_inplace(&mut comp, &b[n - 1], key); // comp
let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc
return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key);
}
pub(super) fn arbitrary_bit_comparator<E: BooleanGates>(
evaluator: &mut E,
a: &[E::Ciphertext],
b: &[E::Ciphertext],
key: &E::Key,
) -> E::Ciphertext {
assert!(a.len() == b.len());
let n = a.len();
// handle MSB
let mut comp = evaluator.not(&b[n - 1]);
evaluator.and_inplace(&mut comp, &a[n - 1], key);
let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key);
return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key);
}

+ 218
- 52
src/utils.rs

@ -1,9 +1,14 @@
use std::usize;
use std::{usize, vec};
use itertools::Itertools;
use num_traits::{PrimInt, Signed};
use itertools::{izip, Itertools};
use num_traits::{One, PrimInt, Signed};
use crate::RandomUniformDist;
use crate::{
backend::Modulus,
decomposer::NumInfo,
random::{RandomElementInModulus, RandomFill},
Matrix, RowEntity, RowMut,
};
pub trait WithLocal { pub trait WithLocal {
fn with_local<F, R>(func: F) -> R fn with_local<F, R>(func: F) -> R
where where
@ -12,26 +17,78 @@ pub trait WithLocal {
fn with_local_mut<F, R>(func: F) -> R fn with_local_mut<F, R>(func: F) -> R
where where
F: Fn(&mut Self) -> R; F: Fn(&mut Self) -> R;
fn with_local_mut_mut<F, R>(func: &mut F) -> R
where
F: FnMut(&mut Self) -> R;
}
pub trait Global {
fn global() -> &'static Self;
}
pub(crate) trait ShoupMul {
fn representation(value: Self, q: Self) -> Self;
fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self;
}
impl ShoupMul for u64 {
#[inline]
fn representation(value: Self, q: Self) -> Self {
((value as u128 * (1u128 << 64)) / q as u128) as u64
}
#[inline]
/// Returns a * b % q
fn mul(a: Self, b: Self, b_shoup: Self, q: Self) -> Self {
(b.wrapping_mul(a))
.wrapping_sub(q.wrapping_mul(((b_shoup as u128 * a as u128) >> 64) as u64))
}
}
pub(crate) trait ToShoup {
type Modulus;
fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self;
}
impl ToShoup for u64 {
type Modulus = u64;
fn to_shoup(value: &Self, modulus: Self) -> Self {
((*value as u128 * (1u128 << 64)) / modulus as u128) as u64
}
}
impl ToShoup for Vec<Vec<u64>> {
type Modulus = u64;
fn to_shoup(value: &Self, modulus: Self::Modulus) -> Self {
let (row, col) = value.dimension();
let mut shoup_value = vec![vec![0u64; col]; row];
izip!(shoup_value.iter_mut(), value.iter()).for_each(|(shoup_r, r)| {
izip!(shoup_r.iter_mut(), r.iter()).for_each(|(s, e)| {
*s = u64::to_shoup(e, modulus);
})
});
shoup_value
}
} }
pub fn fill_random_ternary_secret_with_hamming_weight< pub fn fill_random_ternary_secret_with_hamming_weight<
T: Signed, T: Signed,
R: RandomUniformDist<[u8], Parameters = u8> + RandomUniformDist<usize, Parameters = usize>,
R: RandomFill<[u8]> + RandomElementInModulus<usize, usize>,
>( >(
out: &mut [T], out: &mut [T],
hamming_weight: usize, hamming_weight: usize,
rng: &mut R, rng: &mut R,
) { ) {
let mut bytes = vec![0u8; hamming_weight.div_ceil(8)]; let mut bytes = vec![0u8; hamming_weight.div_ceil(8)];
RandomUniformDist::<[u8]>::random_fill(rng, &0, &mut bytes);
RandomFill::<[u8]>::random_fill(rng, &mut bytes);
let size = out.len(); let size = out.len();
let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec(); let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec();
let mut bit_index = 0; let mut bit_index = 0;
let mut byte_index = 0; let mut byte_index = 0;
for _ in 0..hamming_weight {
let mut s_index = 0usize;
RandomUniformDist::<usize>::random_fill(rng, &secret_indices.len(), &mut s_index);
for i in 0..hamming_weight {
let s_index = RandomElementInModulus::<usize, usize>::random(rng, &secret_indices.len());
let curr_bit = (bytes[byte_index] >> bit_index) & 1; let curr_bit = (bytes[byte_index] >> bit_index) & 1;
if curr_bit == 1 { if curr_bit == 1 {
@ -41,7 +98,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight<
} }
secret_indices[s_index] = *secret_indices.last().unwrap(); secret_indices[s_index] = *secret_indices.last().unwrap();
secret_indices.truncate(secret_indices.len());
secret_indices.truncate(secret_indices.len() - 1);
if bit_index == 7 { if bit_index == 7 {
bit_index = 0; bit_index = 0;
@ -62,7 +119,7 @@ fn is_probably_prime(candidate: u64) -> bool {
/// - $prime \lt upper_bound$ /// - $prime \lt upper_bound$
/// - $\log{prime} = num_bits$ /// - $\log{prime} = num_bits$
/// - `prime % modulo == 1` /// - `prime % modulo == 1`
pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option<u64> {
pub(crate) fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option<u64> {
let leading_zeros = (64 - num_bits) as u32; let leading_zeros = (64 - num_bits) as u32;
let mut tentative_prime = upper_bound - 1; let mut tentative_prime = upper_bound - 1;
@ -107,15 +164,11 @@ pub fn mod_exponent(a: u64, mut b: u64, q: u64) -> u64 {
out out
} }
pub fn mod_inverse(a: u64, q: u64) -> u64 {
pub(crate) fn mod_inverse(a: u64, q: u64) -> u64 {
mod_exponent(a, q - 2, q) mod_exponent(a, q - 2, q)
} }
pub fn shoup_representation_fq(v: u64, q: u64) -> u64 {
((v as u128 * (1u128 << 64)) / q as u128) as u64
}
pub fn negacyclic_mul<T: PrimInt, F: Fn(&T, &T) -> T>(
pub(crate) fn negacyclic_mul<T: PrimInt, F: Fn(&T, &T) -> T>(
a: &[T], a: &[T],
b: &[T], b: &[T],
mul: F, mul: F,
@ -138,54 +191,167 @@ pub fn negacyclic_mul T>(
return r; return r;
} }
pub trait TryConvertFrom<T: ?Sized> {
type Parameters: ?Sized;
/// Returns a polynomial X^{emebedding_factor * si} \mod {Z_Q / X^{N}+1}
pub(crate) fn encode_x_pow_si_with_emebedding_factor<
R: RowEntity + RowMut,
M: Modulus<Element = R::Element>,
>(
si: i32,
embedding_factor: usize,
ring_size: usize,
modulus: &M,
) -> R
where
R::Element: One,
{
assert!((si.abs() as usize) < ring_size);
let mut m = R::zeros(ring_size);
let si = si * (embedding_factor as i32);
if si < 0 {
// X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N
m.as_mut()[ring_size - (si.abs() as usize)] = modulus.neg_one();
} else {
m.as_mut()[si as usize] = R::Element::one();
}
m
}
fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self;
pub(crate) fn puncture_p_rng<S: Default + Copy, R: RandomFill<S>>(
p_rng: &mut R,
times: usize,
) -> S {
let mut out = S::default();
for _ in 0..times {
RandomFill::<S>::random_fill(p_rng, &mut out);
}
return out;
} }
impl TryConvertFrom<[i32]> for Vec<Vec<u32>> {
type Parameters = u32;
fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self {
let row0 = value
.iter()
.map(|v| {
let is_neg = v.is_negative();
let v_u32 = v.abs() as u32;
pub(crate) fn log2<T: PrimInt + NumInfo>(v: &T) -> usize {
if (*v & (*v - T::one())) == T::zero() {
// value is power of 2
(T::BITS - v.leading_zeros() - 1) as usize
} else {
(T::BITS - v.leading_zeros()) as usize
}
}
assert!(v_u32 < *parameters);
pub trait TryConvertFrom1<T: ?Sized, P> {
fn try_convert_from(value: &T, parameters: &P) -> Self;
}
if is_neg {
parameters - v_u32
} else {
v_u32
}
})
.collect_vec();
impl<P: Modulus<Element = u64>> TryConvertFrom1<[i64], P> for Vec<u64> {
fn try_convert_from(value: &[i64], parameters: &P) -> Self {
value
.iter()
.map(|v| parameters.map_element_from_i64(*v))
.collect_vec()
}
}
vec![row0]
impl<P: Modulus<Element = u64>> TryConvertFrom1<[i32], P> for Vec<u64> {
fn try_convert_from(value: &[i32], parameters: &P) -> Self {
value
.iter()
.map(|v| parameters.map_element_from_i64(*v as i64))
.collect_vec()
} }
} }
impl TryConvertFrom<[i32]> for Vec<Vec<u64>> {
type Parameters = u64;
fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self {
let row0 = value
impl<P: Modulus> TryConvertFrom1<[P::Element], P> for Vec<i64> {
fn try_convert_from(value: &[P::Element], parameters: &P) -> Self {
value
.iter() .iter()
.map(|v| {
let is_neg = v.is_negative();
let v_u64 = v.abs() as u64;
.map(|v| parameters.map_element_to_i64(v))
.collect_vec()
}
}
#[cfg(test)]
pub(crate) mod tests {
use std::fmt::Debug;
use num_traits::ToPrimitive;
use crate::random::DefaultSecureRng;
assert!(v_u64 < *parameters);
use super::fill_random_ternary_secret_with_hamming_weight;
if is_neg {
parameters - v_u64
} else {
v_u64
#[derive(Clone)]
pub(crate) struct Stats<T> {
pub(crate) samples: Vec<T>,
}
impl<T> Default for Stats<T> {
fn default() -> Self {
Stats { samples: vec![] }
}
}
impl<T: Copy + ToPrimitive + Debug> Stats<T>
where
// T: for<'a> Sum<&'a T>,
T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T>,
{
pub(crate) fn new() -> Self {
Self { samples: vec![] }
}
pub(crate) fn mean(&self) -> f64 {
self.samples.iter().sum::<T>().to_f64().unwrap() / (self.samples.len() as f64)
}
pub(crate) fn variance(&self) -> f64 {
let mean = self.mean();
// diff
let diff_sq = self
.samples
.iter()
.map(|v| {
let t = v.to_f64().unwrap() - mean;
t * t
})
.into_iter()
.sum::<f64>();
diff_sq / (self.samples.len() as f64 - 1.0)
}
pub(crate) fn std_dev(&self) -> f64 {
self.variance().sqrt()
}
pub(crate) fn add_many_samples(&mut self, values: &[T]) {
self.samples.extend(values.iter());
}
pub(crate) fn add_sample(&mut self, value: T) {
self.samples.push(value)
}
pub(crate) fn merge_in(&mut self, other: &Self) {
self.samples.extend(other.samples.iter());
}
}
#[test]
fn ternary_secret_has_correct_hw() {
let mut rng = DefaultSecureRng::new();
for n in 4..15 {
let ring_size = 1 << n;
let mut out = vec![0i32; ring_size];
fill_random_ternary_secret_with_hamming_weight(&mut out, ring_size >> 1, &mut rng);
// check hamming weight of out equals ring_size/2
let mut non_zeros = 0;
out.iter().for_each(|i| {
if *i != 0 {
non_zeros += 1;
} }
})
.collect_vec();
});
vec![row0]
assert_eq!(ring_size >> 1, non_zeros);
}
} }
} }

Loading…
Cancel
Save