use backend::hal::{
    api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare},
    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat},
};

use crate::layouts::{
    GGSWCiphertext, Infos,
    prepared::{Prepare, PrepareAlloc},
};

// NOTE(review): the generic parameter lists (`<...>`) appear to be missing from this copy
// of the file: e.g. `impl GGSWCiphertextPrepared, B>` and `PrepareAlloc, B>>` contain
// unbalanced `>`, and `Backend`, `Data`, `DataMut` and `DataRef` are imported but never
// referenced. The signatures below are reproduced exactly as found; restore the generics
// from version control before building.

/// A GGSW ciphertext whose matrix has been converted by `vmp_prepare` into a [`VmpPMat`],
/// together with the `basek`, `k` and `digits` metadata of the ciphertext it came from.
///
/// The rank is not stored as a field; it is derived from the matrix's output column count
/// (see `rank`).
#[derive(PartialEq, Eq)]
pub struct GGSWCiphertextPrepared {
    /// Prepared matrix; `alloc` creates it with `rows` rows, `rank + 1` input columns,
    /// `rank + 1` output columns and `ceil(k / basek)` as its size.
    pub(crate) data: VmpPMat,
    /// The `basek` parameter; together with `k` it determines `size = ceil(k / basek)`.
    pub(crate) basek: usize,
    /// The `k` parameter; together with `basek` it determines `size = ceil(k / basek)`.
    pub(crate) k: usize,
    /// The `digits` parameter; `alloc` requires `rows * digits <= ceil(k / basek)`.
    pub(crate) digits: usize,
}

/// Checks the layout preconditions shared by `alloc` and `bytes_of`, so both entry points
/// accept exactly the same parameter sets.
///
/// `size` is `ceil(k / basek)`.
///
/// # Panics
///
/// Always panics if `rows * digits > size`. In debug builds, also panics if `digits == 0`
/// or `size <= digits`.
fn check_ggsw_layout(size: usize, rows: usize, digits: usize) {
    debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
    debug_assert!(
        size > digits,
        "invalid ggsw: ceil(k/basek): {} <= digits: {}",
        size,
        digits
    );
    assert!(
        rows * digits <= size,
        "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
        rows,
        digits,
        size
    );
}

impl GGSWCiphertextPrepared, B> {
    /// Allocates a new prepared GGSW ciphertext.
    ///
    /// `n`, `rows`, `rank + 1` (used as both the input and the output column count) and
    /// `size = ceil(k / basek)` are forwarded to `vmp_pmat_alloc`. The `basek`, `k` and
    /// `digits` parameters are stored alongside the matrix.
    ///
    /// # Panics
    ///
    /// - If `basek == 0` (division by zero in `div_ceil`).
    /// - If `rows * digits > ceil(k / basek)`.
    /// - In debug builds, if `digits == 0` or `ceil(k / basek) <= digits`.
    pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self
    where
        Module: VmpPMatAlloc,
    {
        let size: usize = k.div_ceil(basek);
        check_ggsw_layout(size, rows, digits);

        Self {
            // `rank + 1` columns on both sides; `rank()` recovers the rank from `cols_out`.
            data: module.vmp_pmat_alloc(n, rows, rank + 1, rank + 1, size),
            basek,
            k,
            digits,
        }
    }

    /// Returns the number of bytes `vmp_pmat_alloc_bytes` reports for the matrix that
    /// `alloc` would create with the same arguments.
    ///
    /// # Panics
    ///
    /// Same conditions as `alloc`: `basek == 0`, `rows * digits > ceil(k / basek)`, and in
    /// debug builds `digits == 0` or `ceil(k / basek) <= digits`.
    pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize
    where
        Module: VmpPMatAllocBytes,
    {
        let size: usize = k.div_ceil(basek);
        // Same validation as `alloc` (previously this path skipped the `digits > 0` check).
        check_ggsw_layout(size, rows, digits);

        module.vmp_pmat_alloc_bytes(n, rows, rank + 1, rank + 1, size)
    }
}

impl Infos for GGSWCiphertextPrepared {
    type Inner = VmpPMat;

    /// Returns the prepared matrix.
    fn inner(&self) -> &Self::Inner {
        &self.data
    }

    /// Returns the stored `basek` parameter.
    fn basek(&self) -> usize {
        self.basek
    }

    /// Returns the stored `k` parameter.
    fn k(&self) -> usize {
        self.k
    }
}

impl GGSWCiphertextPrepared {
    /// Returns the rank, derived as the matrix's output column count minus one.
    ///
    /// This mirrors `alloc`, which creates `rank + 1` output columns.
    pub fn rank(&self) -> usize {
        self.data.cols_out() - 1
    }

    /// Returns the stored `digits` parameter.
    pub fn digits(&self) -> usize {
        self.digits
    }
}

impl GGSWCiphertextPrepared {
    /// Borrows the prepared matrix.
    pub fn data(&self) -> &VmpPMat {
        &self.data
    }
}

impl Prepare> for GGSWCiphertextPrepared
where
    Module: VmpPrepare,
{
    /// Prepares `other` into `self`.
    ///
    /// Runs `vmp_prepare` from `other.data` into `self.data` (using `scratch`), then copies
    /// `k`, `basek` and `digits` from `other`. The matrix dimensions are not changed here.
    /// NOTE(review): `self` is presumably expected to have been allocated with `other`'s
    /// rows and rank, with `vmp_prepare` enforcing that. Verify this against
    /// `vmp_prepare`.
    fn prepare(&mut self, module: &Module, other: &GGSWCiphertext, scratch: &mut Scratch) {
        module.vmp_prepare(&mut self.data, &other.data, scratch);
        self.k = other.k;
        self.basek = other.basek;
        self.digits = other.digits;
    }
}

impl PrepareAlloc, B>> for GGSWCiphertext
where
    Module: VmpPMatAlloc + VmpPrepare,
{
    /// Allocates a prepared ciphertext with this ciphertext's `n`, `basek`, `k`, `rows`,
    /// `digits` and `rank`, then prepares `self` into it.
    fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGSWCiphertextPrepared, B> {
        let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc(
            module,
            self.n(),
            self.basek(),
            self.k(),
            self.rows(),
            self.digits(),
            self.rank(),
        );
        ggsw_prepared.prepare(module, self, scratch);
        ggsw_prepared
    }
}