use poulpy_hal::{
    api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
    layouts::{
        Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef,
        ZnxInfos,
    },
};

use crate::layouts::{
    Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GetDegree, LWEInfos, Rank,
    TorusPrecision,
};

/// A GGSW ciphertext stored in the backend-prepared [`VmpPMat`] layout,
/// ready for vector-matrix product (VMP) operations.
#[derive(PartialEq, Eq)]
pub struct GGSWPrepared<D: Data, B: Backend> {
    pub(crate) data: VmpPMat<D, B>,
    pub(crate) k: TorusPrecision,
    pub(crate) base2k: Base2K,
    pub(crate) dsize: Dsize,
}

impl<D: Data, B: Backend> LWEInfos for GGSWPrepared<D, B> {
    fn n(&self) -> Degree {
        Degree(self.data.n() as u32)
    }

    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }

    fn size(&self) -> usize {
        self.data.size()
    }
}

impl<D: Data, B: Backend> GLWEInfos for GGSWPrepared<D, B> {
    fn rank(&self) -> Rank {
        Rank(self.data.cols_out() as u32 - 1)
    }
}

impl<D: Data, B: Backend> GGSWInfos for GGSWPrepared<D, B> {
    fn dsize(&self) -> Dsize {
        self.dsize
    }

    fn dnum(&self) -> Dnum {
        Dnum(self.data.rows() as u32)
    }
}

/// Allocation and byte-size queries for [`GGSWPrepared`].
pub trait GGSWPreparedAlloc<B: Backend>
where
    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf<B>,
{
    fn alloc_ggsw_prepared(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> GGSWPrepared<Vec<u8>, B> {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );
        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize: {} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );
        GGSWPrepared {
            data: self.vmp_pmat_alloc(dnum.into(), (rank + 1).into(), (rank + 1).into(), size),
            k,
            base2k,
            dsize,
        }
    }

    fn alloc_ggsw_prepared_from_infos<A>(&self, infos: &A) -> GGSWPrepared<Vec<u8>, B>
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.alloc_ggsw_prepared(
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }

    fn bytes_of_ggsw_prepared(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> usize {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );
        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize: {} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );
        self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size)
    }

    fn bytes_of_ggsw_prepared_from_infos<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.bytes_of_ggsw_prepared(
            infos.base2k(),
            infos.k(),
            infos.dnum(),
            infos.dsize(),
            infos.rank(),
        )
    }
}

impl<B: Backend> GGSWPreparedAlloc<B> for Module<B> where
    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf<B>
{
}

impl<B: Backend> GGSWPrepared<Vec<u8>, B> {
    pub fn alloc_from_infos<M, A>(module: &M, infos: &A) -> Self
    where
        A: GGSWInfos,
        M: GGSWPreparedAlloc<B>,
    {
        module.alloc_ggsw_prepared_from_infos(infos)
    }

    pub fn alloc<M>(
        module: &M,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> Self
    where
        M: GGSWPreparedAlloc<B>,
    {
        module.alloc_ggsw_prepared(base2k, k, dnum, dsize, rank)
    }

    pub fn bytes_of_from_infos<M, A>(module: &M, infos: &A) -> usize
    where
        A: GGSWInfos,
        M: GGSWPreparedAlloc<B>,
    {
        module.bytes_of_ggsw_prepared_from_infos(infos)
    }

    pub fn bytes_of<M>(
        module: &M,
        base2k: Base2K,
        k: TorusPrecision,
        dnum: Dnum,
        dsize: Dsize,
        rank: Rank,
    ) -> usize
    where
        M: GGSWPreparedAlloc<B>,
    {
        module.bytes_of_ggsw_prepared(base2k, k, dnum, dsize, rank)
    }
}

impl<D: DataRef, B: Backend> GGSWPrepared<D, B> {
    pub fn data(&self) -> &VmpPMat<D, B> {
        &self.data
    }
}

/// Conversion of a [`GGSW`] into its prepared layout through the backend's
/// [`VmpPrepare`] implementation.
pub trait GGSWPrepare<B: Backend>
where
    Self: GetDegree + VmpPrepareTmpBytes + VmpPrepare<B>,
{
    fn ggsw_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        assert_eq!(self.ring_degree(), infos.n());
        self.vmp_prepare_tmp_bytes(
            infos.dnum().into(),
            (infos.rank() + 1).into(),
            (infos.rank() + 1).into(),
            infos.size(),
        )
    }

    fn ggsw_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>)
    where
        R: GGSWPreparedToMut<B>,
        O: GGSWToRef,
    {
        let mut res: GGSWPrepared<&mut [u8], B> = res.to_mut();
        let other: GGSW<&[u8]> = other.to_ref();

        assert_eq!(res.n(), self.ring_degree());
        assert_eq!(other.n(), self.ring_degree());
        assert_eq!(res.k, other.k);
        assert_eq!(res.base2k, other.base2k);
        assert_eq!(res.dsize, other.dsize);

        self.vmp_prepare(&mut res.data, &other.data, scratch);
    }
}

impl<B: Backend> GGSWPrepare<B> for Module<B> where
    Self: GetDegree + VmpPrepareTmpBytes + VmpPrepare<B>
{
}

impl<B: Backend> GGSWPrepared<Vec<u8>, B> {
    pub fn prepare_tmp_bytes<M, A>(&self, module: &M, infos: &A) -> usize
    where
        A: GGSWInfos,
        M: GGSWPrepare<B>,
    {
        module.ggsw_prepare_tmp_bytes(infos)
    }
}

impl<D: DataMut, B: Backend> GGSWPrepared<D, B> {
    pub fn prepare<M, O>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
    where
        O: GGSWToRef,
        M: GGSWPrepare<B>,
    {
        module.ggsw_prepare(self, other, scratch);
    }
}

pub trait GGSWPreparedToMut<B: Backend> {
    fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B>;
}

impl<D: DataMut, B: Backend> GGSWPreparedToMut<B> for GGSWPrepared<D, B> {
    fn to_mut(&mut self) -> GGSWPrepared<&mut [u8], B> {
        GGSWPrepared {
            base2k: self.base2k,
            k: self.k,
            dsize: self.dsize,
            data: self.data.to_mut(),
        }
    }
}

pub trait GGSWPreparedToRef<B: Backend> {
    fn to_ref(&self) -> GGSWPrepared<&[u8], B>;
}

impl<D: DataRef, B: Backend> GGSWPreparedToRef<B> for GGSWPrepared<D, B> {
    fn to_ref(&self) -> GGSWPrepared<&[u8], B> {
        GGSWPrepared {
            base2k: self.base2k,
            k: self.k,
            dsize: self.dsize,
            data: self.data.to_ref(),
        }
    }
}
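
// A minimal usage sketch (editor addition, not part of the original API): it
// stays generic over the traits defined above, so it relies only on this
// module's own items. Assumptions: `Vec<u8>` implements `DataMut` (as the
// owned buffer type in poulpy_hal), and `scratch` was sized with at least
// `module.ggsw_prepare_tmp_bytes(ggsw)` bytes by whatever scratch allocator
// the backend provides.
#[allow(dead_code)]
fn prepare_ggsw_sketch<B, M, O>(
    module: &M,
    ggsw: &O,
    scratch: &mut Scratch<B>,
) -> GGSWPrepared<Vec<u8>, B>
where
    B: Backend,
    M: GGSWPreparedAlloc<B> + GGSWPrepare<B>,
    O: GGSWToRef + GGSWInfos,
{
    // Allocate an owned prepared GGSW with the same layout parameters
    // (n, base2k, k, dnum, dsize, rank) as the source ciphertext.
    let mut prepared: GGSWPrepared<Vec<u8>, B> = GGSWPrepared::alloc_from_infos(module, ggsw);
    // Run the backend-specific preparation into the VMP-ready layout.
    prepared.prepare(module, ggsw, scratch);
    prepared
}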