use backend::hal::{
    api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform},
    layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo},
};

use crate::layouts::{
    GGLWETensorKey, Infos,
    compressed::{Decompress, GGLWESwitchingKeyCompressed},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

use std::fmt;

/// Compressed tensor key: a list of compressed switching keys.
///
/// `alloc` creates `rank * (rank + 1) / 2` keys (clamped to at least one),
/// i.e. one per index pair `(i, j)` with `i <= j < rank`.
#[derive(PartialEq, Eq, Clone)]
pub struct GGLWETensorKeyCompressed {
    pub(crate) keys: Vec>,
}

// Debug output reuses the Display representation.
impl fmt::Debug for GGLWETensorKeyCompressed {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self)
    }
}

impl FillUniform for GGLWETensorKeyCompressed {
    /// Fills every stored switching key from `source`, in storage order.
    fn fill_uniform(&mut self, source: &mut sampling::source::Source) {
        self.keys
            .iter_mut()
            .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(source))
    }
}

impl Reset for GGLWETensorKeyCompressed
where
    MatZnx: Reset,
{
    /// Resets every stored switching key.
    fn reset(&mut self) {
        self.keys
            .iter_mut()
            .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.reset())
    }
}

impl fmt::Display for GGLWETensorKeyCompressed {
    /// Writes a header line followed by each key prefixed with its index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): the header says "GLWE..." while the type is "GGLWE...";
        // left unchanged because it is observable output.
        writeln!(f, "(GLWETensorKeyCompressed)",)?;
        for (i, key) in self.keys.iter().enumerate() {
            // No separator is emitted between entries; line breaks presumably
            // come from the key's own Display output — TODO confirm.
            write!(f, "{}: {}", i, key)?;
        }
        Ok(())
    }
}

/// Number of switching keys stored for `rank`: `rank * (rank + 1) / 2`,
/// clamped to at least one so that `keys[0]` always exists.
///
/// Shared by `alloc` and `bytes_of` so the allocated layout and the size
/// estimate can never disagree.
fn tensor_key_pairs(rank: usize) -> usize {
    (((rank + 1) * rank) >> 1).max(1)
}

impl GGLWETensorKeyCompressed> {
    /// Allocates `tensor_key_pairs(rank)` compressed switching keys, each created
    /// with `GGLWESwitchingKeyCompressed::alloc(n, basek, k, rows, digits, 1, rank)`
    /// (presumably `rank_in = 1`, `rank_out = rank` — confirm against that type).
    pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
        // Collecting from an exact-size range reserves the vector once.
        Self {
            keys: (0..tensor_key_pairs(rank))
                .map(|_| GGLWESwitchingKeyCompressed::alloc(n, basek, k, rows, digits, 1, rank))
                .collect(),
        }
    }

    /// Byte size of a key produced by `alloc` with the same parameters: the
    /// number of stored pairs times the size of one compressed switching key.
    pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
        tensor_key_pairs(rank) * GGLWESwitchingKeyCompressed::bytes_of(n, basek, k, rows, digits, 1)
    }
}

// Layout metadata is taken from the first key. `alloc` creates every key with
// identical parameters and always creates at least one, so `keys[0]` is
// representative (indexing panics if `keys` is somehow empty).
impl Infos for GGLWETensorKeyCompressed {
    type Inner = MatZnx;

    fn inner(&self) -> &Self::Inner {
        self.keys[0].inner()
    }

    fn basek(&self) -> usize {
        self.keys[0].basek()
    }

    fn k(&self) -> usize {
        self.keys[0].k()
    }
}
// Parameter accessors; all read the first stored switching key.
impl GGLWETensorKeyCompressed {
    pub fn rank(&self) -> usize {
        self.keys[0].rank()
    }

    pub fn digits(&self) -> usize {
        self.keys[0].digits()
    }

    pub fn rank_in(&self) -> usize {
        self.keys[0].rank_in()
    }

    pub fn rank_out(&self) -> usize {
        self.keys[0].rank_out()
    }
}

impl ReaderFrom for GGLWETensorKeyCompressed {
    /// Reads a `u64` key count and rejects the input with `InvalidData` if it
    /// differs from `self.keys.len()`; then reads each key in storage order.
    fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> {
        let len: usize = reader.read_u64::()? as usize;
        if self.keys.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("self.keys.len()={} != read len={}", self.keys.len(), len),
            ));
        }
        for key in &mut self.keys {
            key.read_from(reader)?;
        }
        Ok(())
    }
}

impl WriterTo for GGLWETensorKeyCompressed {
    /// Writes the key count as a `u64`, followed by each key in storage order
    /// (the mirror image of `read_from`).
    fn write_to(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::(self.keys.len() as u64)?;
        for key in &self.keys {
            key.write_to(writer)?;
        }
        Ok(())
    }
}

impl GGLWETensorKeyCompressed {
    /// Returns the switching key stored for the unordered index pair `(i, j)`.
    ///
    /// `i` and `j` are swapped if `i > j`. Keys are laid out row-major over the
    /// upper triangle `i <= j`: `(0,0), (0,1), ..., (0,rank-1), (1,1), ...`,
    /// so pair `(i, j)` lives at index `i * rank + j - i * (i + 1) / 2`.
    ///
    /// # Panics
    /// Panics if `max(i, j) >= rank.max(1)`.
    pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed {
        if i > j {
            std::mem::swap(&mut i, &mut j);
        };
        let rank: usize = self.rank();
        // Without this check an out-of-range `j` can silently alias another
        // pair's key: with rank = 3, (0, 3) maps to the same slot as (1, 1).
        // `rank.max(1)` keeps the single (0, 0) slot addressable when rank == 0.
        assert!(
            j < rank.max(1),
            "invalid index pair: max(i, j)={} >= rank={}",
            j,
            rank
        );
        &mut self.keys[i * rank + j - (i * (i + 1) / 2)]
    }
}

impl Decompress> for GGLWETensorKey {
    /// Decompresses `other` into `self`, pairing keys by position.
    fn decompress(&mut self, module: &Module, other: &GGLWETensorKeyCompressed)
    where
        Module: VecZnxFillUniform + VecZnxCopy,
    {
        // The length check only runs in debug builds. In release, `zip` stops at
        // the shorter list, so a mismatch would leave trailing keys untouched.
        #[cfg(debug_assertions)]
        {
            assert_eq!(
                self.keys.len(),
                other.keys.len(),
                "invalid receiver: self.keys.len()={} != other.keys.len()={}",
                self.keys.len(),
                other.keys.len()
            );
        }
        self.keys
            .iter_mut()
            .zip(other.keys.iter())
            .for_each(|(a, b)| {
                a.decompress(module, b);
            });
    }
}