use poulpy_hal::{ layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub struct GGLWETensorKeyLayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, pub rank: Rank, pub dnum: Dnum, pub dsize: Dsize, } #[derive(PartialEq, Eq, Clone)] pub struct GGLWETensorKey { pub(crate) keys: Vec>, } impl LWEInfos for GGLWETensorKey { fn n(&self) -> Degree { self.keys[0].n() } fn base2k(&self) -> Base2K { self.keys[0].base2k() } fn k(&self) -> TorusPrecision { self.keys[0].k() } fn size(&self) -> usize { self.keys[0].size() } } impl GLWEInfos for GGLWETensorKey { fn rank(&self) -> Rank { self.keys[0].rank_out() } } impl GGLWEInfos for GGLWETensorKey { fn rank_in(&self) -> Rank { self.rank_out() } fn rank_out(&self) -> Rank { self.keys[0].rank_out() } fn dsize(&self) -> Dsize { self.keys[0].dsize() } fn dnum(&self) -> Dnum { self.keys[0].dnum() } } impl LWEInfos for GGLWETensorKeyLayout { fn n(&self) -> Degree { self.n } fn base2k(&self) -> Base2K { self.base2k } fn k(&self) -> TorusPrecision { self.k } } impl GLWEInfos for GGLWETensorKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } impl GGLWEInfos for GGLWETensorKeyLayout { fn rank_in(&self) -> Rank { self.rank } fn dsize(&self) -> Dsize { self.dsize } fn rank_out(&self) -> Rank { self.rank } fn dnum(&self) -> Dnum { self.dnum } } impl fmt::Debug for GGLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } impl FillUniform for GGLWETensorKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() .for_each(|key: &mut GGLWESwitchingKey| key.fill_uniform(log_bound, source)) } } impl fmt::Display for GGLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; for (i, key) in self.keys.iter().enumerate() { write!(f, "{i}: {key}")?; } Ok(()) } } impl GGLWETensorKey> { pub fn alloc(infos: &A) -> Self where A: GGLWEInfos, { assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); Self::alloc_with( infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum(), infos.dsize(), ) } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let mut keys: Vec>> = Vec::new(); let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GGLWESwitchingKey::alloc_with( n, base2k, k, Rank(1), rank, dnum, dsize, )); }); Self { keys } } pub fn alloc_bytes(infos: &A) -> usize where A: GGLWEInfos, { assert_eq!( infos.rank_in(), infos.rank_out(), "rank_in != rank_out is not supported for GGLWETensorKey" ); let rank_out: usize = infos.rank_out().into(); let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); pairs * GGLWESwitchingKey::alloc_bytes_with( infos.n(), infos.base2k(), infos.k(), Rank(1), infos.rank_out(), infos.dnum(), infos.dsize(), ) } pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; pairs * GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, Rank(1), rank, dnum, dsize) } } impl GGLWETensorKey { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; let rank: usize = self.rank_out().into(); &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } impl GGLWETensorKey { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKey { if i > j { std::mem::swap(&mut i, &mut j); }; let rank: usize = self.rank_out().into(); &self.keys[i * rank + j - (i * (i + 1) / 2)] } } impl ReaderFrom for GGLWETensorKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { let len: usize = reader.read_u64::()? as usize; if self.keys.len() != len { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("self.keys.len()={} != read len={}", self.keys.len(), len), )); } for key in &mut self.keys { key.read_from(reader)?; } Ok(()) } } impl WriterTo for GGLWETensorKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.keys.len() as u64)?; for key in &self.keys { key.write_to(writer)?; } Ok(()) } }