use poulpy_hal::{
    layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo},
    source::Source,
};

use crate::layouts::{
    Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TensorKey, TensorKeyToMut, TorusPrecision,
    compressed::{
        GLWESwitchingKeyCompressed, GLWESwitchingKeyCompressedAlloc, GLWESwitchingKeyCompressedToMut,
        GLWESwitchingKeyCompressedToRef, GLWESwitchingKeyDecompress,
    },
};

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;

/// Compressed variant of [`TensorKey`]: one [`GLWESwitchingKeyCompressed`] per unordered pair
/// `(i, j)` of secret columns, packed in upper-triangular row-major order.
#[derive(PartialEq, Eq, Clone)]
pub struct TensorKeyCompressed<D: Data> {
    pub(crate) keys: Vec<GLWESwitchingKeyCompressed<D>>,
}

impl<D: Data> LWEInfos for TensorKeyCompressed<D> {
    fn n(&self) -> Degree {
        self.keys[0].n()
    }

    fn base2k(&self) -> Base2K {
        self.keys[0].base2k()
    }

    fn k(&self) -> TorusPrecision {
        self.keys[0].k()
    }

    fn size(&self) -> usize {
        self.keys[0].size()
    }
}

impl<D: Data> GLWEInfos for TensorKeyCompressed<D> {
    fn rank(&self) -> Rank {
        self.rank_out()
    }
}

impl<D: Data> GGLWEInfos for TensorKeyCompressed<D> {
    fn rank_in(&self) -> Rank {
        self.rank_out()
    }

    fn rank_out(&self) -> Rank {
        self.keys[0].rank_out()
    }

    fn dsize(&self) -> Dsize {
        self.keys[0].dsize()
    }

    fn dnum(&self) -> Dnum {
        self.keys[0].dnum()
    }
}

impl<D: Data> fmt::Debug for TensorKeyCompressed<D> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{self}")
    }
}

impl<D: DataMut> FillUniform for TensorKeyCompressed<D> {
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        self.keys
            .iter_mut()
            .for_each(|key: &mut GLWESwitchingKeyCompressed<D>| key.fill_uniform(log_bound, source))
    }
}

impl<D: Data> fmt::Display for TensorKeyCompressed<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "(GLWETensorKeyCompressed)")?;
        for (i, key) in self.keys.iter().enumerate() {
            write!(f, "{i}: {key}")?;
        }
        Ok(())
    }
}

pub trait TensorKeyCompressedAlloc
where
    Self: GLWESwitchingKeyCompressedAlloc,
{
    fn alloc_tensor_key_compressed(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        rank: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> TensorKeyCompressed<Vec<u8>> {
        // One switching key per unordered pair (i, j): rank * (rank + 1) / 2, at least 1.
        let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
        TensorKeyCompressed {
            keys: (0..pairs)
                .map(|_| self.alloc_glwe_switching_key_compressed(base2k, k, Rank(1), rank, dnum, dsize))
                .collect(),
        }
    }

    fn alloc_tensor_key_compressed_from_infos<A>(&self, infos: &A) -> TensorKeyCompressed<Vec<u8>>
    where
        A: GGLWEInfos,
    {
        assert_eq!(
            infos.rank_in(),
            infos.rank_out(),
            "rank_in != rank_out is not supported for GGLWETensorKeyCompressed"
        );
        self.alloc_tensor_key_compressed(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
    }

    fn bytes_of_tensor_key_compressed(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
        let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize;
        pairs * self.bytes_of_glwe_switching_key_compressed(base2k, k, Rank(1), dnum, dsize)
    }

    fn bytes_of_tensor_key_compressed_from_infos<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        self.bytes_of_tensor_key_compressed(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
    }
}

impl TensorKeyCompressed<Vec<u8>> {
    pub fn alloc_from_infos<M, A>(module: &M, infos: &A) -> Self
    where
        A: GGLWEInfos,
        M: TensorKeyCompressedAlloc,
    {
        module.alloc_tensor_key_compressed_from_infos(infos)
    }

    pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self
    where
        M: TensorKeyCompressedAlloc,
    {
        module.alloc_tensor_key_compressed(base2k, k, rank, dnum, dsize)
    }

    pub fn bytes_of_from_infos<M, A>(module: &M, infos: &A) -> usize
    where
        A: GGLWEInfos,
        M: TensorKeyCompressedAlloc,
    {
        module.bytes_of_tensor_key_compressed_from_infos(infos)
    }

    pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
    where
        M: TensorKeyCompressedAlloc,
    {
        module.bytes_of_tensor_key_compressed(base2k, k, rank, dnum, dsize)
    }
}

impl<D: DataMut> ReaderFrom for TensorKeyCompressed<D> {
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        if self.keys.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("self.keys.len()={} != read len={}", self.keys.len(), len),
            ));
        }
        for key in &mut self.keys {
            key.read_from(reader)?;
        }
        Ok(())
    }
}

impl<D: DataRef> WriterTo for TensorKeyCompressed<D> {
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
        for key in &self.keys {
            key.write_to(writer)?;
        }
        Ok(())
    }
}

impl<D: DataMut> TensorKeyCompressed<D> {
    /// Returns the switching key for the pair `(i, j)`; the order of `i` and `j` does not matter.
    pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKeyCompressed<D> {
        if i > j {
            std::mem::swap(&mut i, &mut j);
        }
        let rank: usize = self.rank_out().into();
        // Upper-triangular row-major packing: (i, j) with i <= j maps to i*rank + j - i*(i+1)/2.
        &mut self.keys[i * rank + j - (i * (i + 1) / 2)]
    }
}

pub trait TensorKeyDecompress
where
    Self: GLWESwitchingKeyDecompress,
{
    fn decompress_tensor_key<R, O>(&self, res: &mut R, other: &O)
    where
        R: TensorKeyToMut,
        O: TensorKeyCompressedToRef,
    {
        let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut();
        let other: &TensorKeyCompressed<&[u8]> = &other.to_ref();
        assert_eq!(
            res.keys.len(),
            other.keys.len(),
            "invalid receiver: res.keys.len()={} != other.keys.len()={}",
            res.keys.len(),
            other.keys.len()
        );
        for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
            self.decompress_glwe_switching_key(a, b);
        }
    }
}

impl<B: Backend> TensorKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {}

impl<D: DataMut> TensorKey<D> {
    pub fn decompress<M, O>(&mut self, module: &M, other: &O)
    where
        O: TensorKeyCompressedToRef,
        M: TensorKeyDecompress,
    {
        module.decompress_tensor_key(self, other);
    }
}

pub trait TensorKeyCompressedToMut {
    fn to_mut(&mut self) -> TensorKeyCompressed<&mut [u8]>;
}

impl<D: DataMut> TensorKeyCompressedToMut for TensorKeyCompressed<D>
where
    GLWESwitchingKeyCompressed<D>: GLWESwitchingKeyCompressedToMut,
{
    fn to_mut(&mut self) -> TensorKeyCompressed<&mut [u8]> {
        TensorKeyCompressed {
            keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
        }
    }
}

pub trait TensorKeyCompressedToRef {
    fn to_ref(&self) -> TensorKeyCompressed<&[u8]>;
}

impl<D: DataRef> TensorKeyCompressedToRef for TensorKeyCompressed<D>
where
    GLWESwitchingKeyCompressed<D>: GLWESwitchingKeyCompressedToRef,
{
    fn to_ref(&self) -> TensorKeyCompressed<&[u8]> {
        TensorKeyCompressed {
            keys: self.keys.iter().map(|c| c.to_ref()).collect(),
        }
    }
}
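
// Minimal test sketch of the upper-triangular pair packing assumed by `at_mut` and by the
// `pairs` count in `TensorKeyCompressedAlloc`: for every unordered pair (i, j) with i <= j,
// the index `i * rank + j - i * (i + 1) / 2` should hit each slot of a vector of length
// `rank * (rank + 1) / 2` exactly once. The module and test names are illustrative; the test
// exercises only the formula, so it needs no backend or module setup.
#[cfg(test)]
mod triangular_indexing_sketch {
    #[test]
    fn pair_index_is_a_bijection() {
        for rank in 1usize..=6 {
            let pairs = ((rank + 1) * rank / 2).max(1);
            let mut seen = vec![false; pairs];
            for i in 0..rank {
                for j in i..rank {
                    // Same formula as `TensorKeyCompressed::at_mut`.
                    let idx = i * rank + j - (i * (i + 1) / 2);
                    assert!(idx < pairs, "index out of range for rank {rank}");
                    assert!(!seen[idx], "index {idx} hit twice for rank {rank}");
                    seen[idx] = true;
                }
            }
            assert!(seen.iter().all(|&s| s), "not all slots covered for rank {rank}");
        }
    }
}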