use backend::hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, }; use crate::layouts::{ GGLWETensorKey, Infos, prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; #[derive(PartialEq, Eq)] pub struct GGLWETensorKeyPrepared { pub(crate) keys: Vec>, } impl GGLWETensorKeyPrepared, B> { pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: VmpPMatAlloc, { let mut keys: Vec, B>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GGLWESwitchingKeyPrepared::alloc( module, n, basek, k, rows, digits, 1, rank, )); }); Self { keys } } pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: VmpPMatAllocBytes, { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); pairs * GGLWESwitchingKeyPrepared::bytes_of(module, n, basek, k, rows, digits, 1, rank) } } impl Infos for GGLWETensorKeyPrepared { type Inner = VmpPMat; fn inner(&self) -> &Self::Inner { &self.keys[0].inner() } fn basek(&self) -> usize { self.keys[0].basek() } fn k(&self) -> usize { self.keys[0].k() } } impl GGLWETensorKeyPrepared { pub fn rank(&self) -> usize { self.keys[0].rank() } pub fn rank_in(&self) -> usize { self.keys[0].rank_in() } pub fn rank_out(&self) -> usize { self.keys[0].rank_out() } pub fn digits(&self) -> usize { self.keys[0].digits() } } impl GGLWETensorKeyPrepared { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; let rank: usize = self.rank(); &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } impl GGLWETensorKeyPrepared { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWESwitchingKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; let rank: usize = self.rank(); &self.keys[i * rank + j - (i * (i + 1) / 2)] } } impl Prepare> for GGLWETensorKeyPrepared where Module: VmpPMatPrepare, { fn prepare(&mut self, module: &Module, other: &GGLWETensorKey, scratch: &mut Scratch) { #[cfg(debug_assertions)] { assert_eq!(self.keys.len(), other.keys.len()); } self.keys .iter_mut() .zip(other.keys.iter()) .for_each(|(a, b)| { a.prepare(module, b, scratch); }); } } impl PrepareAlloc, B>> for GGLWETensorKey where Module: VmpPMatAlloc + VmpPMatPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWETensorKeyPrepared, B> { let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc( module, self.n(), self.basek(), self.k(), self.rows(), self.digits(), self.rank(), ); tsk_prepared.prepare(module, self, scratch); tsk_prepared } }