ckks addition & substraction of ciphertexts

This commit is contained in:
2025-07-05 16:58:41 +02:00
parent 6090116a8b
commit 84c54e8edd
2 changed files with 113 additions and 11 deletions

View File

@@ -153,10 +153,9 @@ mod tests {
#[test] #[test]
fn test_encode_decode() -> Result<()> { fn test_encode_decode() -> Result<()> {
const Q: u64 = 1024; const Q: u64 = 1024;
// const N: usize = 4; // ie. m=2*n=8 const N: usize = 32;
const N: usize = 16;
let T = 16; // WIP let T = 128; // WIP
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..100 { for _ in 0..100 {

View File

@@ -72,7 +72,7 @@ impl<const Q: u64, const N: usize> CKKS<Q, N> {
fn decrypt( fn decrypt(
&self, // TODO maybe rm? &self, // TODO maybe rm?
sk: SecretKey<Q, N>, sk: &SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>), c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<R<N>> { ) -> Result<R<N>> {
let m = c.0.clone() + c.1 * sk.0; let m = c.0.clone() + c.1 * sk.0;
@@ -95,10 +95,25 @@ impl<const Q: u64, const N: usize> CKKS<Q, N> {
sk: SecretKey<Q, N>, sk: SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>), c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<Vec<C<f64>>> { ) -> Result<Vec<C<f64>>> {
let d = self.decrypt(sk, c)?; let d = self.decrypt(&sk, c)?;
self.encoder.decode(&d) self.encoder.decode(&d)
} }
pub fn add(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
Ok((&c0.0 + &c1.0, &c0.1 + &c1.1))
}
pub fn sub(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
Ok((&c0.0 - &c1.0, &c0.1 + &c1.1))
}
} }
#[cfg(test)] #[cfg(test)]
@@ -108,10 +123,10 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; const Q: u64 = 2u64.pow(16) + 1;
const T: u64 = 16; const N: usize = 32;
const N: usize = 8; const T: u64 = 50;
let scale_factor_u64 = 512_u64; // delta let scale_factor_u64 = 512_u64; // delta
let scale_factor = C::<f64>::new(512.0, 0.0); // delta let scale_factor = C::<f64>::new(scale_factor_u64 as f64, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
@@ -124,7 +139,7 @@ mod tests {
let m = m_raw * scale_factor_u64; let m = m_raw * scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?; let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?; let m_decrypted = ckks.decrypt(&sk, ct)?;
let m_decrypted: Vec<u64> = m_decrypted let m_decrypted: Vec<u64> = m_decrypted
.coeffs() .coeffs()
@@ -141,8 +156,8 @@ mod tests {
#[test] #[test]
fn test_encode_encrypt_decrypt_decode() -> Result<()> { fn test_encode_encrypt_decrypt_decode() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 16; const T: u64 = 16;
const N: usize = 4;
let scale_factor = C::<f64>::new(512.0, 0.0); // delta let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
@@ -155,6 +170,7 @@ mod tests {
.take(N / 2) .take(N / 2)
.collect(); .collect();
let m: R<N> = ckks.encoder.encode(&z)?; let m: R<N> = ckks.encoder.encode(&z)?;
println!("{}", m);
// sanity check // sanity check
{ {
@@ -167,7 +183,8 @@ mod tests {
} }
let ct = ckks.encrypt(&mut rng, &pk, &m)?; let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(sk, ct)?; let m_decrypted = ckks.decrypt(&sk, ct)?;
println!("{}", m_decrypted);
let z_decrypted = ckks.encoder.decode(&m_decrypted)?; let z_decrypted = ckks.encoder.decode(&m_decrypted)?;
@@ -180,4 +197,90 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_add() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 10;
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;
let ct2 = ckks.add(&ct0, &ct1)?;
let m2_decrypted = ckks.decrypt(&sk, ct2)?;
let z_decrypted = ckks.encoder.decode(&m2_decrypted)?;
let rounded_z_decrypted: Vec<C<f64>> = z_decrypted
.iter()
.map(|&c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
let expected_z2: Vec<C<f64>> = itertools::zip_eq(z0, z1).map(|(a, b)| a + b).collect();
assert_eq!(rounded_z_decrypted, expected_z2);
}
Ok(())
}
#[test]
fn test_sub() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 10;
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
.collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;
let ct2 = ckks.sub(&ct0, &ct1)?;
let m2_decrypted = ckks.decrypt(&sk, ct2)?;
let z_decrypted = ckks.encoder.decode(&m2_decrypted)?;
let rounded_z_decrypted: Vec<C<f64>> = z_decrypted
.iter()
.map(|&c| C::<f64>::new(c.re.round(), c.im.round()))
.collect();
let expected_z2: Vec<C<f64>> = itertools::zip_eq(z0, z1).map(|(a, b)| a - b).collect();
assert_eq!(rounded_z_decrypted, expected_z2);
}
Ok(())
}
} }