use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::{GGLWECiphertextPrepared, Prepare, PrepareAlloc, PrepareScratchSpace}, }; #[derive(PartialEq, Eq)] pub struct GGLWESwitchingKeyPrepared { pub(crate) key: GGLWECiphertextPrepared, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } impl LWEInfos for GGLWESwitchingKeyPrepared { fn n(&self) -> Degree { self.key.n() } fn base2k(&self) -> Base2K { self.key.base2k() } fn k(&self) -> TorusPrecision { self.key.k() } fn size(&self) -> usize { self.key.size() } } impl GLWEInfos for GGLWESwitchingKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } impl GGLWEInfos for GGLWESwitchingKeyPrepared { fn rank_in(&self) -> Rank { self.key.rank_in() } fn rank_out(&self) -> Rank { self.key.rank_out() } fn dsize(&self) -> Dsize { self.key.dsize() } fn dnum(&self) -> Dnum { self.key.dnum() } } impl GGLWESwitchingKeyPrepared, B> { pub fn alloc(module: &Module, infos: &A) -> Self where A: GGLWEInfos, Module: VmpPMatAlloc, { debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); GGLWESwitchingKeyPrepared::, B> { key: GGLWECiphertextPrepared::alloc(module, infos), sk_in_n: 0, sk_out_n: 0, } } pub fn alloc_with( module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, ) -> Self where Module: VmpPMatAlloc, { GGLWESwitchingKeyPrepared::, B> { key: GGLWECiphertextPrepared::alloc_with(module, base2k, k, rank_in, rank_out, dnum, dsize), sk_in_n: 0, sk_out_n: 0, } } pub fn alloc_bytes(module: &Module, infos: &A) -> usize where A: GGLWEInfos, Module: VmpPMatAllocBytes, { debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); GGLWECiphertextPrepared::alloc_bytes(module, infos) } pub fn alloc_bytes_with( module: &Module, base2k: Base2K, k: TorusPrecision, rank_in: Rank, rank_out: Rank, dnum: Dnum, dsize: Dsize, ) -> usize where Module: VmpPMatAllocBytes, { GGLWECiphertextPrepared::alloc_bytes_with(module, base2k, k, rank_in, rank_out, dnum, dsize) } } impl PrepareScratchSpace for GGLWESwitchingKeyPrepared, B> where GGLWECiphertextPrepared, B>: PrepareScratchSpace, { fn prepare_scratch_space(module: &Module, infos: &A) -> usize { GGLWECiphertextPrepared::prepare_scratch_space(module, infos) } } impl Prepare> for GGLWESwitchingKeyPrepared where Module: VmpPrepare, { fn prepare(&mut self, module: &Module, other: &GGLWESwitchingKey, scratch: &mut Scratch) { self.key.prepare(module, &other.key, scratch); self.sk_in_n = other.sk_in_n; self.sk_out_n = other.sk_out_n; } } impl PrepareAlloc, B>> for GGLWESwitchingKey where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWESwitchingKeyPrepared, B> { let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc(module, self); atk_prepared.prepare(module, self, scratch); atk_prepared } }