use base2k::{ Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, }; pub trait Infos { type Inner: ZnxInfos; fn inner(&self) -> &Self::Inner; /// Returns the ring degree of the polynomials. fn n(&self) -> usize { self.inner().n() } /// Returns the base two logarithm of the ring dimension of the polynomials. fn log_n(&self) -> usize { self.inner().log_n() } /// Returns the number of rows. fn rows(&self) -> usize { self.inner().rows() } /// Returns the number of polynomials in each row. fn cols(&self) -> usize { self.inner().cols() } /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q())); size } /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { self.rows() * self.cols() * self.size() } /// Returns the base 2 logarithm of the ciphertext base. fn log_base2k(&self) -> usize; /// Returns the base 2 logarithm of the ciphertext modulus. fn log_q(&self) -> usize; } pub struct RLWECt{ data: VecZnx, log_base2k: usize, log_q: usize, } impl Infos for RLWECt { type Inner = T; fn inner(&self) -> &Self::Inner { &self.data } fn log_base2k(&self) -> usize { self.log_base2k } fn log_q(&self) -> usize { self.log_q } } impl DataView for Ciphertext { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for Ciphertext { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } pub struct Plaintext { data: T, log_base2k: usize, log_q: usize, } impl Infos for Plaintext { type Inner = T; fn inner(&self) -> &Self::Inner { &self.data } fn log_base2k(&self) -> usize { self.log_base2k } fn log_q(&self) -> usize { self.log_q } } impl Plaintext { pub fn data(&self) -> &T { &self.data } pub fn data_mut(&mut self) -> &mut T { &mut self.data } } pub(crate) type CtVecZnx = Ciphertext>; pub(crate) type CtVecZnxDft = Ciphertext>; pub(crate) type CtMatZnxDft = Ciphertext>; pub(crate) type PtVecZnx = Plaintext>; pub(crate) type PtVecZnxDft = Plaintext>; pub(crate) type PtMatZnxDft = Plaintext>; impl VecZnxToMut for Ciphertext where D: VecZnxToMut, { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { self.data_mut().to_mut() } } impl VecZnxToRef for Ciphertext where D: VecZnxToRef, { fn to_ref(&self) -> VecZnx<&[u8]> { self.data().to_ref() } } impl Ciphertext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { data: module.new_vec_znx(cols, derive_size(log_base2k, log_q)), log_base2k: log_base2k, log_q: log_q, } } } impl VecZnxToMut for Plaintext where D: VecZnxToMut, { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { self.data_mut().to_mut() } } impl VecZnxToRef for Plaintext where D: VecZnxToRef, { fn to_ref(&self) -> VecZnx<&[u8]> { self.data().to_ref() } } impl Plaintext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { Self { data: module.new_vec_znx(1, derive_size(log_base2k, log_q)), log_base2k: log_base2k, log_q: log_q, } } } impl VecZnxDftToMut for Ciphertext where D: VecZnxDftToMut, { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { self.data_mut().to_mut() } } impl VecZnxDftToRef for Ciphertext where D: VecZnxDftToRef, { fn to_ref(&self) -> VecZnxDft<&[u8], B> { self.data().to_ref() } } impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)), log_base2k: log_base2k, log_q: log_q, } } } impl MatZnxDftToMut for Ciphertext where D: MatZnxDftToMut, { fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { self.data_mut().to_mut() } } impl MatZnxDftToRef for Ciphertext where D: MatZnxDftToRef, { fn to_ref(&self) -> MatZnxDft<&[u8], B> { self.data().to_ref() } } impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { Self { data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)), log_base2k: log_base2k, log_q: log_q, } } } pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize { (log_q + log_base2k - 1) / log_base2k }