use poulpy_hal::{
    api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes},
    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
};

use crate::layouts::{
    Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
};

/// Prepared form of a [`GGLWE`]: the matrix is held in a [`VmpPMat`] and is filled from a
/// `GGLWE` by `GGLWEPrepare::gglwe_prepare` (which calls `vmp_prepare`).
///
/// Only `k`, `base2k` and `dsize` are stored explicitly. The ring degree, `rank_in`,
/// `rank_out`, `dnum` and `size` are all read back from the shape of `data` (see the
/// `LWEInfos` / `GGLWEInfos` impls below).
#[derive(PartialEq, Eq)]
pub struct GGLWEPrepared {
    /// Backend matrix, allocated with `dnum` rows, `rank_in` input columns,
    /// `rank_out + 1` output columns and `size = ceil(k / base2k)`.
    pub(crate) data: VmpPMat,
    /// Torus precision `k` (used to derive `size` at allocation time).
    pub(crate) k: TorusPrecision,
    /// `base2k` parameter (used to derive `size` at allocation time).
    pub(crate) base2k: Base2K,
    /// `dsize` parameter; must match the source `GGLWE` when preparing.
    pub(crate) dsize: Dsize,
}

impl LWEInfos for GGLWEPrepared {
    /// Ring degree, read from the backend matrix.
    fn n(&self) -> Degree {
        Degree(self.data.n() as u32)
    }

    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }

    /// `size` of the backend matrix (set to `ceil(k / base2k)` by the allocators).
    fn size(&self) -> usize {
        self.data.size()
    }
}

impl GLWEInfos for GGLWEPrepared {
    /// The GLWE rank of a `GGLWEPrepared` is its output rank.
    fn rank(&self) -> Rank {
        self.rank_out()
    }
}

impl GGLWEInfos for GGLWEPrepared {
    /// Input rank = number of input columns of the backend matrix.
    fn rank_in(&self) -> Rank {
        Rank(self.data.cols_in() as u32)
    }

    /// Output rank = output columns minus one (the allocators request `rank_out + 1`
    /// output columns).
    fn rank_out(&self) -> Rank {
        Rank(self.data.cols_out() as u32 - 1)
    }

    fn dsize(&self) -> Dsize {
        self.dsize
    }

    /// Number of rows of the backend matrix.
    fn dnum(&self) -> Dnum {
        Dnum(self.data.rows() as u32)
    }
}

/// Allocation helpers for [`GGLWEPrepared`], provided for any type that can report its
/// degree and allocate a `VmpPMat` (or report its byte size).
pub trait GGLWEPreparedAlloc
where
    Self: GetDegree + VmpPMatAlloc + VmpPMatAllocBytes,
{
    /// Allocates a `GGLWEPrepared` whose matrix has `dnum` rows, `rank_in` input columns,
    /// `rank_out + 1` output columns and `size = ceil(k / base2k)`.
    ///
    /// # Panics
    ///
    /// - In debug builds, if `ceil(k / base2k) <= dsize`.
    /// - If `dnum * dsize > ceil(k / base2k)`.
    fn gglwe_prepared_alloc(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        rank_in: Rank,
        rank_out: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> GGLWEPrepared, B> {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        // NOTE(review): this bound is strict (`size > dsize`), but the assert below accepts
        // `dnum * dsize == size`. For example, dnum = 1 with dsize = size passes the assert
        // but trips this debug_assert. Confirm which bound is intended.
        debug_assert!(
            size as u32 > dsize.0,
            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        // NOTE(review): `dnum.0 * dsize.0` is a plain u32 multiply. With overflow checks off
        // (the release-profile default) a huge product wraps and could pass this check.
        // Consider `checked_mul`.
        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        GGLWEPrepared {
            // One output column beyond `rank_out`; `GGLWEInfos::rank_out` subtracts it back.
            data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size),
            k,
            base2k,
            dsize,
        }
    }

    /// Same as [`Self::gglwe_prepared_alloc`], with every parameter taken from `infos`.
    ///
    /// # Panics
    ///
    /// If `infos.n()` differs from `self.n()`. Also panics wherever
    /// `gglwe_prepared_alloc` does.
    fn gglwe_prepared_alloc_from_infos(&self, infos: &A) -> GGLWEPrepared, B>
    where
        A: GGLWEInfos,
    {
        assert_eq!(self.n(), infos.n());
        self.gglwe_prepared_alloc(
            infos.base2k(),
            infos.k(),
            infos.rank_in(),
            infos.rank_out(),
            infos.dnum(),
            infos.dsize(),
        )
    }

    /// Byte size reported by `vmp_pmat_alloc_bytes` for the matrix that
    /// [`Self::gglwe_prepared_alloc`] would allocate with the same parameters.
    ///
    /// # Panics
    ///
    /// Runs the same parameter checks as `gglwe_prepared_alloc`.
    fn gglwe_prepared_alloc_bytes(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        rank_in: Rank,
        rank_out: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> usize {
        // NOTE(review): these checks duplicate `gglwe_prepared_alloc`; keep the two in sync.
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
            size as u32 > dsize.0,
            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        self.vmp_pmat_alloc_bytes(dnum.into(), rank_in.into(), (rank_out + 1).into(), size)
    }

    /// Same as [`Self::gglwe_prepared_alloc_bytes`], with every parameter taken from `infos`.
    ///
    /// # Panics
    ///
    /// If `infos.n()` differs from `self.n()`. Also panics wherever
    /// `gglwe_prepared_alloc_bytes` does.
    fn gglwe_prepared_alloc_bytes_from_infos(&self, infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        assert_eq!(self.n(), infos.n());
        self.gglwe_prepared_alloc_bytes(
            infos.base2k(),
            infos.k(),
            infos.rank_in(),
            infos.rank_out(),
            infos.dnum(),
            infos.dsize(),
        )
    }
}

/// Blanket impl: any `Module` meeting the trait's bounds gets the default allocation methods.
impl GGLWEPreparedAlloc for Module where Module: GetDegree + VmpPMatAlloc + VmpPMatAllocBytes {}

/// Inherent constructors and size queries that forward to [`GGLWEPreparedAlloc`] on `module`.
impl GGLWEPrepared, B>
where
    Module: GGLWEPreparedAlloc,
{
    /// Forwards to `GGLWEPreparedAlloc::gglwe_prepared_alloc_from_infos`.
    pub fn alloc_from_infos(module: &Module, infos: &A) -> Self
    where
        A: GGLWEInfos,
    {
        module.gglwe_prepared_alloc_from_infos(infos)
    }

    /// Forwards to `GGLWEPreparedAlloc::gglwe_prepared_alloc`.
    pub fn alloc(
        module: &Module,
        base2k: Base2K,
        k: TorusPrecision,
        rank_in: Rank,
        rank_out: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> Self {
        module.gglwe_prepared_alloc(base2k, k, rank_in, rank_out, dnum, dsize)
    }

    /// Forwards to `GGLWEPreparedAlloc::gglwe_prepared_alloc_bytes_from_infos`.
    pub fn alloc_bytes_from_infos(module: &Module, infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        module.gglwe_prepared_alloc_bytes_from_infos(infos)
    }

    /// Forwards to `GGLWEPreparedAlloc::gglwe_prepared_alloc_bytes`.
    pub fn alloc_bytes(
        module: &Module,
        base2k: Base2K,
        k: TorusPrecision,
        rank_in: Rank,
        rank_out: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> usize {
        module.gglwe_prepared_alloc_bytes(base2k, k, rank_in, rank_out, dnum, dsize)
    }
}

/// Conversion of a [`GGLWE`] into a [`GGLWEPrepared`] through the backend's `vmp_prepare`.
pub trait GGLWEPrepare
where
    Self: GetDegree + VmpPrepareTmpBytes + VmpPrepare,
{
    /// Scratch size reported by `vmp_prepare_tmp_bytes` for a matrix with `infos.dnum()` rows,
    /// `infos.rank_in()` input columns, `infos.rank() + 1` output columns and `infos.size()`.
    fn gglwe_prepare_tmp_bytes(&self, infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        // NOTE(review): this uses `rank()` where the allocators use `rank_out()`. They agree
        // for `GGLWEPrepared`, whose `GLWEInfos::rank` returns `rank_out`. Confirm they also
        // agree for every other `GGLWEInfos` implementer passed here.
        self.vmp_prepare_tmp_bytes(
            infos.dnum().into(),
            infos.rank_in().into(),
            (infos.rank() + 1).into(),
            infos.size(),
        )
    }

    /// Prepares `other` into `res` by calling `vmp_prepare`, using `scratch` as temporary space.
    ///
    /// # Panics
    ///
    /// - If the degree of `res` or of `other` differs from `self.n()`.
    /// - If `res` and `other` differ in `base2k`, `k` or `dsize`.
    fn gglwe_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch)
    where
        R: GGLWEPreparedToMut,
        O: GGLWEToRef,
    {
        let mut res: GGLWEPrepared<&mut [u8], B> = res.to_mut();
        let other: GGLWE<&[u8]> = other.to_ref();

        assert_eq!(res.n(), self.n());
        assert_eq!(other.n(), self.n());
        assert_eq!(res.base2k, other.base2k);
        assert_eq!(res.k, other.k);
        assert_eq!(res.dsize, other.dsize);

        // `rank_in`, `rank_out`, `dnum` and `size` are not compared here.
        // NOTE(review): presumably `vmp_prepare` validates the matrix shapes. Confirm.
        self.vmp_prepare(&mut res.data, &other.data, scratch);
    }
}

/// Blanket impl: any `Module` meeting the trait's bounds gets the default prepare methods.
impl GGLWEPrepare for Module where Self: GetDegree + VmpPrepareTmpBytes + VmpPrepare {}

impl GGLWEPrepared
where
    Module: GGLWEPrepare,
{
    /// Fills `self` from `other`. Forwards to `GGLWEPrepare::gglwe_prepare` and
    /// panics wherever that method does.
    pub fn prepare(&mut self, module: &Module, other: &O, scratch: &mut Scratch)
    where
        O: GGLWEToRef,
    {
        module.gglwe_prepare(self, other, scratch);
    }
}

impl GGLWEPrepared, B> {
    /// Scratch size for preparing into a value shaped like `self`. Uses `self` as the
    /// `infos` argument of `GGLWEPrepare::gglwe_prepare_tmp_bytes`.
    pub fn prepare_tmp_bytes(&self, module: &Module) -> usize
    where
        Module: GGLWEPrepare,
    {
        module.gglwe_prepare_tmp_bytes(self)
    }
}

/// Borrows a `GGLWEPrepared` as a mutable byte-slice-backed view.
pub trait GGLWEPreparedToMut {
    fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B>;
}

impl GGLWEPreparedToMut for GGLWEPrepared {
    /// Copies the `k` / `base2k` / `dsize` metadata and mutably reborrows `data`.
    fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> {
        GGLWEPrepared {
            k: self.k,
            base2k: self.base2k,
            dsize: self.dsize,
            data: self.data.to_mut(),
        }
    }
}

/// Borrows a `GGLWEPrepared` as a shared byte-slice-backed view.
pub trait GGLWEPreparedToRef {
    fn to_ref(&self) -> GGLWEPrepared<&[u8], B>;
}

impl GGLWEPreparedToRef for GGLWEPrepared {
    /// Copies the `k` / `base2k` / `dsize` metadata and immutably borrows `data`.
    fn to_ref(&self) -> GGLWEPrepared<&[u8], B> {
        GGLWEPrepared {
            k: self.k,
            base2k: self.base2k,
            dsize: self.dsize,
            data: self.data.to_ref(),
        }
    }
}