use poulpy_hal::{
    api::ModuleN,
    layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
};

use crate::{
    GLWEKeyswitch, GLWERotate, ScratchTakeCore,
    layouts::{
        GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToRef, LWE, LWEInfos,
        LWEToMut, Rank,
    },
};

// NOTE(review): in this copy of the file the generic parameter lists look stripped: `R`, `A`,
// `K` and `M` are used in `where` clauses below but never declared, and the header
// `impl LWE> {` does not parse. The comments here describe only the visible logic. Confirm the
// exact generics against the canonical source before building.

/// Extraction of a single LWE sample from a GLWE ciphertext.
///
/// Only a default method is provided. It relies on the `ModuleN` bound for the module's `n()`.
pub trait LWESampleExtract
where
    Self: ModuleN,
{
    /// Writes into `res` the LWE sample built from coefficient 0 of `a`.
    ///
    /// For each of the first `min(res.size(), a.size())` limbs `i` (the second index of
    /// `at`/`at_mut`):
    /// - coefficient 0 of `res` is set to coefficient 0 of column 0 of `a`
    ///   (presumably the GLWE body);
    /// - coefficients `1..=res.n()` of `res` are set to the first `res.n()` coefficients of
    ///   column 1 of `a` (presumably the first mask polynomial).
    ///
    /// `res.data.zero()` runs first, so limbs of `res` past `a.size()` do not keep stale
    /// values. This assumes `ZnxZero::zero` clears every limb.
    ///
    /// Only columns 0 and 1 of `a` are read. NOTE(review): `a` is presumably expected to be a
    /// rank-1 GLWE, as produced inside `lwe_from_glwe`. Extra mask columns would be ignored
    /// without any error — confirm, and consider asserting the rank.
    ///
    /// Mask coefficients are copied in order, with no reversal or negation. NOTE(review):
    /// textbook RLWE sample extraction reorders and negates the mask. The matching LWE secret
    /// key is therefore presumably laid out to compensate — confirm. When `res.n() < a.n()`,
    /// the trailing mask coefficients are dropped.
    ///
    /// # Panics
    /// - if `res.n() > a.n()`;
    /// - if `a.n()` differs from the module's `n()`;
    /// - if `res.base2k() != a.base2k()`;
    /// - in `copy_from_slice`, if a limb row of `res` does not hold exactly `res.n() + 1`
    ///   coefficients.
    fn lwe_sample_extract(&self, res: &mut R, a: &A)
    where
        R: LWEToMut,
        A: GLWEToRef,
    {
        let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert!(res.n() <= a.n());
        assert_eq!(a.n(), self.n() as u32);
        // NOTE(review): `assert_eq!` would print both values on failure.
        assert!(res.base2k() == a.base2k());

        // Limbs present in both operands; `res` limbs past this count stay as cleared below.
        let min_size: usize = res.size().min(a.size());
        let n: usize = res.n().into();

        res.data.zero();
        (0..min_size).for_each(|i| {
            let data_lwe: &mut [i64] = res.data.at_mut(0, i);
            // Coefficient 0 of column 0 of `a`.
            data_lwe[0] = a.data.at(0, i)[0];
            // First `n` coefficients of column 1 of `a`, copied unchanged.
            data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]);
        });
    }
}

/// Blanket impl: every `Module` gets `lwe_sample_extract` through the trait's default method.
impl LWESampleExtract for Module where Self: ModuleN {}

/// Blanket impl: every `Module` that supports GLWE key-switching, sample extraction and GLWE
/// rotation gets `LWEFromGLWE` through the trait's default methods.
impl LWEFromGLWE for Module where Self: GLWEKeyswitch + LWESampleExtract + GLWERotate {}

/// GLWE → LWE conversion in three steps: an optional rotation, a key-switch into a rank-1
/// GLWE, and sample extraction of coefficient 0.
pub trait LWEFromGLWE
where
    Self: GLWEKeyswitch + LWESampleExtract + GLWERotate,
{
    /// Number of scratch bytes used by `lwe_from_glwe`.
    ///
    /// The result is the sum of three terms, mirroring the allocations in `lwe_from_glwe`:
    /// 1. a rank-1 GLWE with the module's `n` and the `base2k`/`k` of `lwe_infos`, which holds
    ///    the key-switch output;
    /// 2. a GLWE shaped like `glwe_infos`, which holds the rotated copy of the input. This
    ///    term is always reserved, although `lwe_from_glwe` only takes it when `a_idx != 0`;
    /// 3. the key-switch scratch for (`res_infos`, `glwe_infos`, `key_infos`).
    fn lwe_from_glwe_tmp_bytes(&self, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize
    where
        R: LWEInfos,
        A: GLWEInfos,
        K: GGLWEInfos,
    {
        // Layout of the rank-1 key-switch output. Term 1 below is computed from the same
        // values. NOTE(review): it could be written as `GLWE::bytes_of_from_infos(&res_infos)`.
        let res_infos: GLWELayout = GLWELayout {
            n: self.n().into(),
            base2k: lwe_infos.base2k(),
            k: lwe_infos.k(),
            rank: Rank(1),
        };
        GLWE::bytes_of(
            self.n().into(),
            lwe_infos.base2k(),
            lwe_infos.k(),
            1u32.into(),
        ) + GLWE::bytes_of_from_infos(glwe_infos)
            + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos)
    }

    /// Writes into `res` the LWE sample obtained from GLWE `a` at index `a_idx`.
    ///
    /// Steps:
    /// 1. If `a_idx != 0`, a scratch GLWE shaped like `a` receives `a` rotated by
    ///    `-(a_idx as i64)` via `glwe_rotate`. This presumably brings coefficient `a_idx` to
    ///    position 0 — confirm against `GLWERotate`.
    /// 2. The input, rotated or not, is key-switched with `key` into a scratch rank-1 GLWE.
    ///    That GLWE uses the module's `n` and takes its `base2k` and `k` from `res`.
    /// 3. `lwe_sample_extract` copies coefficient 0 of that rank-1 GLWE into `res`.
    ///
    /// Scratch usage matches `lwe_from_glwe_tmp_bytes`. NOTE(review): `key` presumably
    /// switches from `a`'s secret to a rank-1 secret — confirm. `a_idx` is not range-checked
    /// here.
    ///
    /// # Panics
    /// - if `a.n()` or `key.n()` differs from the module's `n()`;
    /// - if `res.n()` exceeds the module's `n()`;
    /// - on any assertion raised by the callees.
    fn lwe_from_glwe(&self, res: &mut R, a: &A, a_idx: usize, key: &K, scratch: &mut Scratch)
    where
        R: LWEToMut,
        A: GLWEToRef,
        K: GGLWEPreparedToRef + GGLWEInfos,
        Scratch: ScratchTakeCore,
    {
        let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(key.n(), self.n() as u32);
        assert!(res.n() <= self.n() as u32);

        // The rank-1 key-switch output uses `res`'s base2k and k. This guarantees the base2k
        // equality check in `lwe_sample_extract` holds.
        let glwe_layout: GLWELayout = GLWELayout {
            n: self.n().into(),
            base2k: res.base2k(),
            k: res.k(),
            rank: Rank(1),
        };

        let (mut tmp_glwe_rank_1, scratch_1) = scratch.take_glwe(&glwe_layout);

        match a_idx {
            // No rotation: key-switch `a` directly.
            0 => {
                self.glwe_keyswitch(&mut tmp_glwe_rank_1, a, key, scratch_1);
            }
            // Rotate a scratch copy of `a` by `-a_idx`, then key-switch the copy. The
            // remaining scratch is split again for the copy.
            _ => {
                let (mut tmp_glwe_in, scratch_2) = scratch_1.take_glwe(a);
                self.glwe_rotate(-(a_idx as i64), &mut tmp_glwe_in, a);
                self.glwe_keyswitch(&mut tmp_glwe_rank_1, &tmp_glwe_in, key, scratch_2);
            }
        }

        self.lwe_sample_extract(res, &tmp_glwe_rank_1);
    }
}

// NOTE(review): this header does not parse as written. Its generic arguments appear stripped
// in this copy — confirm against the canonical source.
impl LWE> {
    /// Scratch bytes needed by `LWE::from_glwe`. Forwards to
    /// `LWEFromGLWE::lwe_from_glwe_tmp_bytes` on `module`.
    pub fn from_glwe_tmp_bytes(module: &M, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize
    where
        R: LWEInfos,
        A: GLWEInfos,
        K: GGLWEInfos,
        M: LWEFromGLWE,
    {
        module.lwe_from_glwe_tmp_bytes(lwe_infos, glwe_infos, key_infos)
    }
}

/// Method-style wrappers that forward to the module-level traits.
impl LWE {
    /// Fills `self` from coefficient 0 of `a` by forwarding to
    /// `LWESampleExtract::lwe_sample_extract`. Panics in the same cases as that method.
    pub fn sample_extract(&mut self, module: &M, a: &A)
    where
        A: GLWEToRef,
        M: LWESampleExtract,
    {
        module.lwe_sample_extract(self, a);
    }

    /// Converts GLWE `a` into `self` by forwarding to `LWEFromGLWE::lwe_from_glwe` with the
    /// same `a_idx`, `key` and `scratch`.
    pub fn from_glwe(&mut self, module: &M, a: &A, a_idx: usize, key: &K, scratch: &mut Scratch)
    where
        A: GLWEToRef,
        K: GGLWEPreparedToRef + GGLWEInfos,
        M: LWEFromGLWE,
        Scratch: ScratchTakeCore,
    {
        module.lwe_from_glwe(self, a, a_idx, key, scratch);
    }
}