fixed scratch API

This commit is contained in:
Pro7ech
2025-10-21 10:47:46 +02:00
parent 681ec7e349
commit fef2a2fc27
28 changed files with 112 additions and 153 deletions

View File

@@ -3,8 +3,7 @@ use std::marker::PhantomData;
use poulpy_core::layouts::{Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared};
#[cfg(test)]
use poulpy_core::{
TakeGGSW,
layouts::{GGSW, prepared::GLWESecretPrepared},
layouts::{prepared::GLWESecretPrepared, GGSW}, ScratchTakeCore,
};
use poulpy_hal::{
api::VmpPMatAlloc,
@@ -13,13 +12,12 @@ use poulpy_hal::{
#[cfg(test)]
use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigBytesOf,
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
VecZnxSubInplace, VmpPrepare,
},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
source::Source,
};
@@ -137,7 +135,7 @@ impl<D: DataMut, T: UnsignedInteger + ToBits, BE: Backend> FheUintBlocksPrep<D,
+ VecZnxNormalize<BE>
+ VecZnxSub
+ VmpPrepare<BE>,
Scratch<BE>: TakeVecZnxDft<BE> + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx,
Scratch<BE>: ScratchTakeCore<BE>,
{
#[cfg(debug_assertions)]
{
@@ -146,11 +144,11 @@ impl<D: DataMut, T: UnsignedInteger + ToBits, BE: Backend> FheUintBlocksPrep<D,
assert_eq!(sk.n(), module.n() as u32);
}
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(self);
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(module, self);
let (mut pt, scratch_2) = scratch_1.take_scalar_znx(module.n(), 1);
for i in 0..T::WORD_SIZE {
use poulpy_core::layouts::prepared::Prepare;
use poulpy_hal::layouts::ZnxViewMut;
pt.at_mut(0, 0)[0] = value.bit(i) as i64;
@@ -205,7 +203,6 @@ impl<D: DataRef, T: UnsignedInteger + ToBits> FheUintBlocksPrepDebug<D, T> {
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxAddScalarInplace
+ VecZnxSubInplace,
BE: Backend + TakeVecZnxDftImpl<BE> + TakeVecZnxBigImpl<BE> + ScratchOwnedAllocImpl<BE> + ScratchOwnedBorrowImpl<BE>,
{
for (i, ggsw) in self.blocks.iter().enumerate() {
use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut};