use poulpy_hal::{
    api::{
        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
        VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace,
        VecZnxBigSubSmallNegateInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
    },
    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
};

use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared};

// Scratch-space sizing helpers for the GLWE automorphism operations.
impl GLWECiphertext> {
    /// Returns the number of scratch bytes required by `GLWECiphertext::automorphism`.
    ///
    /// `automorphism` is a key-switch followed by an in-place `vec_znx_automorphism_inplace`
    /// on each column, so this forwards unchanged to `keyswitch_scratch_space` with the
    /// same output, input and key layouts.
    pub fn automorphism_scratch_space(
        module: &Module,
        out_infos: &OUT,
        in_infos: &IN,
        key_infos: &KEY,
    ) -> usize
    where
        OUT: GLWEInfos,
        IN: GLWEInfos,
        KEY: GGLWELayoutInfos,
        Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
    {
        Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos)
    }

    /// Returns the number of scratch bytes required by `GLWECiphertext::automorphism_inplace`.
    ///
    /// Forwards unchanged to `keyswitch_inplace_scratch_space`.
    // NOTE(review): no dedicated scratch-space function for the `automorphism_add*` /
    // `automorphism_sub*` variants is visible here; confirm callers size scratch for the
    // extra `take_vec_znx_dft` those variants perform before key-switching.
    pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize
    where
        OUT: GLWEInfos,
        KEY: GGLWELayoutInfos,
        Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
    {
        Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos)
    }
}

// GLWE automorphism operations.
//
// Every operation key-switches its input with `rhs.key`, then applies the automorphism
// selected by `rhs.p()` to each of the `rank + 1` columns. The `_add` / `_sub*` variants
// fuse a column-wise combination with the un-transformed input before writing the result
// into `self.data`.
//
// NOTE(review): in the fused variants, `lhs.data` / `self.data` is added or subtracted
// directly into `res_big`. `res_big` is then normalised with `rhs.base2k()` as its input
// base. If the ciphertext's `base2k()` can differ from the key's, this mixes two limb
// decompositions. Confirm that `assert_keyswitch*` (or the callers) rule this out.
impl GLWECiphertext {
    /// Writes into `self` the automorphism of `lhs` described by `rhs`.
    ///
    /// Key-switches `lhs` into `self` using `rhs.key`, then applies
    /// `vec_znx_automorphism_inplace(rhs.p(), ..)` to each of the `rank + 1` columns of
    /// `self.data`. Scratch requirement: see `automorphism_scratch_space`.
    pub fn automorphism(
        &mut self,
        module: &Module,
        lhs: &GLWECiphertext,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxAutomorphismInplace
            + VecZnxNormalize
            + VecZnxNormalizeTmpBytes,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        self.keyswitch(module, lhs, &rhs.key, scratch);
        // The key-switched result is already normalised in `self.data`; the automorphism is
        // applied afterwards, column by column, on the small representation.
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
        })
    }

    /// In-place form of `automorphism`: replaces `self` with its automorphism under `rhs`.
    ///
    /// Key-switches `self` in place with `rhs.key`, then applies
    /// `vec_znx_automorphism_inplace(rhs.p(), ..)` to each of the `rank + 1` columns.
    pub fn automorphism_inplace(
        &mut self,
        module: &Module,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxAutomorphismInplace
            + VecZnxNormalize
            + VecZnxNormalizeTmpBytes,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        self.keyswitch_inplace(module, &rhs.key, scratch);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
        })
    }

    /// Writes into `self` the automorphism of the key-switched `lhs`, plus `lhs` itself.
    ///
    /// For each column `i` in `0..rank + 1`:
    /// 1. apply `vec_znx_big_automorphism_inplace(rhs.p(), ..)` to the key-switched
    ///    big-coefficient buffer;
    /// 2. add `lhs.data` column `i`;
    /// 3. normalise from `rhs.base2k()` into `self.data` at `self.base2k()`.
    ///
    /// The previous contents of `self` are not read. They are overwritten, not accumulated into.
    /// Consistency between `self`, `lhs` and `rhs` is only checked in debug builds.
    pub fn automorphism_add(
        &mut self,
        module: &Module,
        lhs: &GLWECiphertext,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
        }
        // DFT buffer with `rank + 1` columns and `rhs.size()` limbs; the remaining scratch
        // (`scratch_1`) is reused by the key-switch, automorphism and normalisation below.
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            // Transform the key-switched column in the big domain, fuse the addition, and only
            // then normalise once into the output.
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }

    /// In-place form of `automorphism_add`: replaces `self` with the automorphism of the
    /// key-switched `self`, plus the original `self`.
    ///
    /// `self.data` is read (for the key-switch and the addition) before column `i` is
    /// overwritten by the normalisation.
    pub fn automorphism_add_inplace(
        &mut self,
        module: &Module,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
        }
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            // Column `i` of `self.data` is still the pre-operation value here; it is only
            // overwritten by the normalisation on the next statement.
            module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }

    /// Writes into `self` the automorphism of the key-switched `lhs`, minus `lhs`.
    ///
    /// Same pipeline as `automorphism_add`, but column `i` of `lhs.data` is combined with
    /// `vec_znx_big_sub_small_inplace`. Presumably this is `res_big -= lhs` — confirm against
    /// poulpy_hal. The previous contents of `self` are overwritten.
    pub fn automorphism_sub_ab(
        &mut self,
        module: &Module,
        lhs: &GLWECiphertext,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxBigSubSmallInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
        }
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            module.vec_znx_big_sub_small_inplace(&mut res_big, i, &lhs.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }

    /// In-place form of `automorphism_sub_ab`.
    ///
    /// Replaces `self` with the automorphism of the key-switched `self`, combined with the
    /// original `self` through `vec_znx_big_sub_small_inplace`.
    pub fn automorphism_sub_inplace(
        &mut self,
        module: &Module,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxBigSubSmallInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
        }
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            module.vec_znx_big_sub_small_inplace(&mut res_big, i, &self.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }

    /// Writes into `self` `lhs` combined with the automorphism of the key-switched `lhs`,
    /// using `vec_znx_big_sub_small_negate_inplace`.
    ///
    /// Presumably this yields `lhs - automorphism(keyswitch(lhs))`, i.e. the operand order
    /// is reversed relative to `automorphism_sub_ab` — confirm against poulpy_hal. The
    /// previous contents of `self` are overwritten.
    pub fn automorphism_sub_negate(
        &mut self,
        module: &Module,
        lhs: &GLWECiphertext,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxBigSubSmallNegateInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
        }
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &lhs.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }

    /// In-place form of `automorphism_sub_negate`.
    ///
    /// Replaces `self` with the original `self` combined with the automorphism of the
    /// key-switched `self`, via `vec_znx_big_sub_small_negate_inplace`.
    pub fn automorphism_sub_negate_inplace(
        &mut self,
        module: &Module,
        rhs: &GGLWEAutomorphismKeyPrepared,
        scratch: &mut Scratch,
    ) where
        Module: VecZnxDftAllocBytes
            + VmpApplyDftToDftTmpBytes
            + VecZnxBigNormalizeTmpBytes
            + VmpApplyDftToDft
            + VmpApplyDftToDftAdd
            + VecZnxDftApply
            + VecZnxIdftApplyConsume
            + VecZnxBigAddSmallInplace
            + VecZnxBigNormalize
            + VecZnxBigAutomorphismInplace
            + VecZnxBigSubSmallNegateInplace
            + VecZnxNormalizeTmpBytes
            + VecZnxNormalize,
        Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx,
    {
        #[cfg(debug_assertions)]
        {
            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
        }
        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
        (0..(self.rank() + 1).into()).for_each(|i| {
            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
            module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &self.data, i);
            module.vec_znx_big_normalize(
                self.base2k().into(),
                &mut self.data,
                i,
                rhs.base2k().into(),
                &res_big,
                i,
                scratch_1,
            );
        })
    }
}