mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added GGSW key-switching along with algo description
This commit is contained in:
@@ -99,6 +99,13 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: VecZnxDftToRef<BACKEND>,
|
||||
B: MatZnxDftToRef<BACKEND>;
|
||||
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: VecZnxDftToRef<BACKEND>,
|
||||
B: MatZnxDftToRef<BACKEND>;
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
|
||||
@@ -301,6 +308,59 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxDftToRef<FFT64> {
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
Reference in New Issue
Block a user