Browse Source

ckks addition & substraction of ciphertexts

gfhe-over-ring-trait
arnaucube 1 month ago
parent
commit
84c54e8edd
2 changed files with 113 additions and 11 deletions
  1. +2
    -3
      ckks/src/encoder.rs
  2. +111
    -8
      ckks/src/lib.rs

+ 2
- 3
ckks/src/encoder.rs

@ -153,10 +153,9 @@ mod tests {
#[test]
fn test_encode_decode() -> Result<()> {
const Q: u64 = 1024;
// const N: usize = 4; // ie. m=2*n=8
const N: usize = 16;
const N: usize = 32;
let T = 16; // WIP
let T = 128; // WIP
let mut rng = rand::thread_rng();
for _ in 0..100 {

+ 111
- 8
ckks/src/lib.rs

@ -72,7 +72,7 @@ impl CKKS {
fn decrypt(
&self, // TODO maybe rm?
sk: SecretKey<Q, N>,
sk: &SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<R<N>> {
let m = c.0.clone() + c.1 * sk.0;
@ -95,10 +95,25 @@ impl CKKS {
sk: SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<Vec<C<f64>>> {
let d = self.decrypt(sk, c)?;
let d = self.decrypt(&sk, c)?;
self.encoder.decode(&d)
}
pub fn add(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
Ok((&c0.0 + &c1.0, &c0.1 + &c1.1))
}
pub fn sub(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
Ok((&c0.0 - &c1.0, &c0.1 + &c1.1))
}
}
#[cfg(test)]
@ -108,10 +123,10 @@ mod tests {
#[test]
fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const T: u64 = 16;
const N: usize = 8;
const N: usize = 32;
const T: u64 = 50;
let scale_factor_u64 = 512_u64; // delta
let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let scale_factor = C::<f64>::new(scale_factor_u64 as f64, 0.0); // delta
let mut rng = rand::thread_rng();
@ -124,7 +139,7 @@ mod tests {
let m = m_raw * scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?;
let m_decrypted = ckks.decrypt(&sk, ct)?;
let m_decrypted: Vec<u64> = m_decrypted
.coeffs()
@ -141,8 +156,8 @@ mod tests {
#[test]
fn test_encode_encrypt_decrypt_decode() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 16;
const N: usize = 4;
let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let mut rng = rand::thread_rng();
@ -155,6 +170,7 @@ mod tests {
.take(N / 2)
.collect();
let m: R<N> = ckks.encoder.encode(&z)?;
println!("{}", m);
// sanity check
{
@ -167,7 +183,8 @@ mod tests {
}
let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?;
let m_decrypted = ckks.decrypt(&sk, ct)?;
println!("{}", m_decrypted);
let z_decrypted = ckks.encoder.decode(&m_decrypted)?;
@ -180,4 +197,90 @@ mod tests {
Ok(())
}
#[test]
fn test_add() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 10;
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;
let ct2 = ckks.add(&ct0, &ct1)?;
let m2_decrypted = ckks.decrypt(&sk, ct2)?;
let z_decrypted = ckks.encoder.decode(&m2_decrypted)?;
let rounded_z_decrypted: Vec<C<f64>> = z_decrypted
.iter()
.map(|&c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
let expected_z2: Vec<C<f64>> = itertools::zip_eq(z0, z1).map(|(a, b)| a + b).collect();
assert_eq!(rounded_z_decrypted, expected_z2);
}
Ok(())
}
#[test]
fn test_sub() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 10;
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;
let ct2 = ckks.sub(&ct0, &ct1)?;
let m2_decrypted = ckks.decrypt(&sk, ct2)?;
let z_decrypted = ckks.encoder.decode(&m2_decrypted)?;
let rounded_z_decrypted: Vec<C<f64>> = z_decrypted
.iter()
.map(|&c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
let expected_z2: Vec<C<f64>> = itertools::zip_eq(z0, z1).map(|(a, b)| a - b).collect();
assert_eq!(rounded_z_decrypted, expected_z2);
}
Ok(())
}
}

Loading…
Cancel
Save