Tracking PR for v0.7.0 releaseal-gkr-basic-workflow
@ -0,0 +1,3 @@ |
|||||
|
[submodule "PQClean"] |
||||
|
path = PQClean |
||||
|
url = https://github.com/PQClean/PQClean.git |
@ -0,0 +1,78 @@ |
|||||
|
#include <stddef.h> |
||||
|
#include <arm_sve.h> |
||||
|
#include "library.h" |
||||
|
#include "rpo_hash.h" |
||||
|
|
||||
|
// The STATE_WIDTH of RPO hash is 12x u64 elements. |
||||
|
// The current generation of SVE-enabled processors - Neoverse V1 |
||||
|
// (e.g. AWS Graviton3) have 256-bit vector registers (4x u64) |
||||
|
// This allows us to split the state into 3 vectors of 4 elements |
||||
|
// and process all 3 independent of each other. |
||||
|
|
||||
|
// We see the biggest performance gains by leveraging both |
||||
|
// vector and scalar operations on parts of the state array. |
||||
|
// Due to high latency of vector operations, the processor is able |
||||
|
// to reorder and pipeline scalar instructions while we wait for |
||||
|
// vector results. This effectively gives us some 'free' scalar |
||||
|
// operations and masks vector latency. |
||||
|
// |
||||
|
// This also means that we can fully saturate all four arithmetic |
||||
|
// units of the processor (2x scalar, 2x SIMD) |
||||
|
// |
||||
|
// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS |
||||
|
// GAIN WIDER REGISTERS. It's quite possible that with 8x u64 |
||||
|
// vectors processing 2 partially filled vectors might |
||||
|
// be easier and faster than dealing with scalar operations |
||||
|
// on the remainder of the array. |
||||
|
// |
||||
|
// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back |
||||
|
// to the regular, already highly-optimized scalar version |
||||
|
// if the conditions are not met. |
||||
|
|
||||
|
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { |
||||
|
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector |
||||
|
|
||||
|
if (vl != 4) { |
||||
|
return false; |
||||
|
} |
||||
|
|
||||
|
svbool_t ptrue = svptrue_b64(); |
||||
|
|
||||
|
svuint64_t state1 = svld1(ptrue, state + 0*vl); |
||||
|
svuint64_t state2 = svld1(ptrue, state + 1*vl); |
||||
|
|
||||
|
svuint64_t const1 = svld1(ptrue, constants + 0*vl); |
||||
|
svuint64_t const2 = svld1(ptrue, constants + 1*vl); |
||||
|
|
||||
|
add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8); |
||||
|
apply_sbox(ptrue, &state1, &state2, state+8); |
||||
|
|
||||
|
svst1(ptrue, state + 0*vl, state1); |
||||
|
svst1(ptrue, state + 1*vl, state2); |
||||
|
|
||||
|
return true; |
||||
|
} |
||||
|
|
||||
|
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { |
||||
|
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector |
||||
|
|
||||
|
if (vl != 4) { |
||||
|
return false; |
||||
|
} |
||||
|
|
||||
|
svbool_t ptrue = svptrue_b64(); |
||||
|
|
||||
|
svuint64_t state1 = svld1(ptrue, state + 0 * vl); |
||||
|
svuint64_t state2 = svld1(ptrue, state + 1 * vl); |
||||
|
|
||||
|
svuint64_t const1 = svld1(ptrue, constants + 0 * vl); |
||||
|
svuint64_t const2 = svld1(ptrue, constants + 1 * vl); |
||||
|
|
||||
|
add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8); |
||||
|
apply_inv_sbox(ptrue, &state1, &state2, state + 8); |
||||
|
|
||||
|
svst1(ptrue, state + 0 * vl, state1); |
||||
|
svst1(ptrue, state + 1 * vl, state2); |
||||
|
|
||||
|
return true; |
||||
|
} |
@ -0,0 +1,12 @@ |
|||||
|
#ifndef CRYPTO_LIBRARY_H |
||||
|
#define CRYPTO_LIBRARY_H |
||||
|
|
||||
|
#include <stdint.h> |
||||
|
#include <stdbool.h> |
||||
|
|
||||
|
#define STATE_WIDTH 12 |
||||
|
|
||||
|
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); |
||||
|
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); |
||||
|
|
||||
|
#endif //CRYPTO_LIBRARY_H |
@ -0,0 +1,221 @@ |
|||||
|
#ifndef RPO_SVE_RPO_HASH_H |
||||
|
#define RPO_SVE_RPO_HASH_H |
||||
|
|
||||
|
#include <arm_sve.h> |
||||
|
#include <stddef.h> |
||||
|
#include <stdint.h> |
||||
|
#include <string.h> |
||||
|
|
||||
|
#define COPY(NAME, VIN1, VIN2, SIN3) \ |
||||
|
svuint64_t NAME ## _1 = VIN1; \ |
||||
|
svuint64_t NAME ## _2 = VIN2; \ |
||||
|
uint64_t NAME ## _3[4]; \ |
||||
|
memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t)) |
||||
|
|
||||
|
#define MULTIPLY(PRED, DEST, OP) \ |
||||
|
mul(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3) |
||||
|
|
||||
|
#define SQUARE(PRED, NAME) \ |
||||
|
sq(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3) |
||||
|
|
||||
|
#define SQUARE_DEST(PRED, DEST, SRC) \ |
||||
|
COPY(DEST, SRC ## _1, SRC ## _2, SRC ## _3); \ |
||||
|
SQUARE(PRED, DEST); |
||||
|
|
||||
|
#define POW_ACC(PRED, NAME, CNT, TAIL) \ |
||||
|
for (size_t i = 0; i < CNT; i++) { \ |
||||
|
SQUARE(PRED, NAME); \ |
||||
|
} \ |
||||
|
MULTIPLY(PRED, NAME, TAIL); |
||||
|
|
||||
|
#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \ |
||||
|
COPY(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \ |
||||
|
POW_ACC(PRED, DEST, CNT, TAIL) |
||||
|
|
||||
|
extern inline void add_constants( |
||||
|
svbool_t pg, |
||||
|
svuint64_t *state1, |
||||
|
svuint64_t *const1, |
||||
|
svuint64_t *state2, |
||||
|
svuint64_t *const2, |
||||
|
uint64_t *state3, |
||||
|
uint64_t *const3 |
||||
|
) { |
||||
|
uint64_t Ms = 0xFFFFFFFF00000001ull; |
||||
|
svuint64_t Mv = svindex_u64(Ms, 0); |
||||
|
|
||||
|
uint64_t p_1 = Ms - const3[0]; |
||||
|
uint64_t p_2 = Ms - const3[1]; |
||||
|
uint64_t p_3 = Ms - const3[2]; |
||||
|
uint64_t p_4 = Ms - const3[3]; |
||||
|
|
||||
|
uint64_t x_1, x_2, x_3, x_4; |
||||
|
uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1); |
||||
|
uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2); |
||||
|
uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3); |
||||
|
uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4); |
||||
|
|
||||
|
state3[0] = x_1 - (uint64_t)adj_1; |
||||
|
state3[1] = x_2 - (uint64_t)adj_2; |
||||
|
state3[2] = x_3 - (uint64_t)adj_3; |
||||
|
state3[3] = x_4 - (uint64_t)adj_4; |
||||
|
|
||||
|
svuint64_t p1 = svsub_x(pg, Mv, *const1); |
||||
|
svuint64_t p2 = svsub_x(pg, Mv, *const2); |
||||
|
|
||||
|
svuint64_t x1 = svsub_x(pg, *state1, p1); |
||||
|
svuint64_t x2 = svsub_x(pg, *state2, p2); |
||||
|
|
||||
|
svbool_t pt1 = svcmplt_u64(pg, *state1, p1); |
||||
|
svbool_t pt2 = svcmplt_u64(pg, *state2, p2); |
||||
|
|
||||
|
*state1 = svsub_m(pt1, x1, (uint32_t)-1); |
||||
|
*state2 = svsub_m(pt2, x2, (uint32_t)-1); |
||||
|
} |
||||
|
|
||||
|
extern inline void mul( |
||||
|
svbool_t pg, |
||||
|
svuint64_t *r1, |
||||
|
const svuint64_t *op1, |
||||
|
svuint64_t *r2, |
||||
|
const svuint64_t *op2, |
||||
|
uint64_t *r3, |
||||
|
const uint64_t *op3 |
||||
|
) { |
||||
|
__uint128_t x_1 = r3[0]; |
||||
|
__uint128_t x_2 = r3[1]; |
||||
|
__uint128_t x_3 = r3[2]; |
||||
|
__uint128_t x_4 = r3[3]; |
||||
|
|
||||
|
x_1 *= (__uint128_t) op3[0]; |
||||
|
x_2 *= (__uint128_t) op3[1]; |
||||
|
x_3 *= (__uint128_t) op3[2]; |
||||
|
x_4 *= (__uint128_t) op3[3]; |
||||
|
|
||||
|
uint64_t x0_1 = x_1; |
||||
|
uint64_t x0_2 = x_2; |
||||
|
uint64_t x0_3 = x_3; |
||||
|
uint64_t x0_4 = x_4; |
||||
|
|
||||
|
svuint64_t l1 = svmul_x(pg, *r1, *op1); |
||||
|
svuint64_t l2 = svmul_x(pg, *r2, *op2); |
||||
|
|
||||
|
uint64_t x1_1 = (x_1 >> 64); |
||||
|
uint64_t x1_2 = (x_2 >> 64); |
||||
|
uint64_t x1_3 = (x_3 >> 64); |
||||
|
uint64_t x1_4 = (x_4 >> 64); |
||||
|
|
||||
|
uint64_t a_1, a_2, a_3, a_4; |
||||
|
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1); |
||||
|
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2); |
||||
|
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3); |
||||
|
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4); |
||||
|
|
||||
|
svuint64_t ls1 = svlsl_x(pg, l1, 32); |
||||
|
svuint64_t ls2 = svlsl_x(pg, l2, 32); |
||||
|
|
||||
|
svuint64_t a1 = svadd_x(pg, l1, ls1); |
||||
|
svuint64_t a2 = svadd_x(pg, l2, ls2); |
||||
|
|
||||
|
svbool_t e1 = svcmplt(pg, a1, l1); |
||||
|
svbool_t e2 = svcmplt(pg, a2, l2); |
||||
|
|
||||
|
svuint64_t as1 = svlsr_x(pg, a1, 32); |
||||
|
svuint64_t as2 = svlsr_x(pg, a2, 32); |
||||
|
|
||||
|
svuint64_t b1 = svsub_x(pg, a1, as1); |
||||
|
svuint64_t b2 = svsub_x(pg, a2, as2); |
||||
|
|
||||
|
b1 = svsub_m(e1, b1, 1); |
||||
|
b2 = svsub_m(e2, b2, 1); |
||||
|
|
||||
|
uint64_t b_1 = a_1 - (a_1 >> 32) - e_1; |
||||
|
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2; |
||||
|
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3; |
||||
|
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4; |
||||
|
|
||||
|
uint64_t r_1, r_2, r_3, r_4; |
||||
|
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1); |
||||
|
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2); |
||||
|
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3); |
||||
|
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4); |
||||
|
|
||||
|
svuint64_t h1 = svmulh_x(pg, *r1, *op1); |
||||
|
svuint64_t h2 = svmulh_x(pg, *r2, *op2); |
||||
|
|
||||
|
svuint64_t tr1 = svsub_x(pg, h1, b1); |
||||
|
svuint64_t tr2 = svsub_x(pg, h2, b2); |
||||
|
|
||||
|
svbool_t c1 = svcmplt_u64(pg, h1, b1); |
||||
|
svbool_t c2 = svcmplt_u64(pg, h2, b2); |
||||
|
|
||||
|
*r1 = svsub_m(c1, tr1, (uint32_t) -1); |
||||
|
*r2 = svsub_m(c2, tr2, (uint32_t) -1); |
||||
|
|
||||
|
uint32_t minus1_1 = 0 - c_1; |
||||
|
uint32_t minus1_2 = 0 - c_2; |
||||
|
uint32_t minus1_3 = 0 - c_3; |
||||
|
uint32_t minus1_4 = 0 - c_4; |
||||
|
|
||||
|
r3[0] = r_1 - (uint64_t)minus1_1; |
||||
|
r3[1] = r_2 - (uint64_t)minus1_2; |
||||
|
r3[2] = r_3 - (uint64_t)minus1_3; |
||||
|
r3[3] = r_4 - (uint64_t)minus1_4; |
||||
|
} |
||||
|
|
||||
|
extern inline void sq(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) { |
||||
|
mul(pg, a, a, b, b, c, c); |
||||
|
} |
||||
|
|
||||
|
extern inline void apply_sbox( |
||||
|
svbool_t pg, |
||||
|
svuint64_t *state1, |
||||
|
svuint64_t *state2, |
||||
|
uint64_t *state3 |
||||
|
) { |
||||
|
COPY(x, *state1, *state2, state3); // copy input to x |
||||
|
SQUARE(pg, x); // x contains input^2 |
||||
|
mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^3 |
||||
|
SQUARE(pg, x); // x contains input^4 |
||||
|
mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^7 |
||||
|
} |
||||
|
|
||||
|
extern inline void apply_inv_sbox( |
||||
|
svbool_t pg, |
||||
|
svuint64_t *state_1, |
||||
|
svuint64_t *state_2, |
||||
|
uint64_t *state_3 |
||||
|
) { |
||||
|
// base^10 |
||||
|
COPY(t1, *state_1, *state_2, state_3); |
||||
|
SQUARE(pg, t1); |
||||
|
|
||||
|
// base^100 |
||||
|
SQUARE_DEST(pg, t2, t1); |
||||
|
|
||||
|
// base^100100 |
||||
|
POW_ACC_DEST(pg, t3, 3, t2, t2); |
||||
|
|
||||
|
// base^100100100100 |
||||
|
POW_ACC_DEST(pg, t4, 6, t3, t3); |
||||
|
|
||||
|
// compute base^100100100100100100100100 |
||||
|
POW_ACC_DEST(pg, t5, 12, t4, t4); |
||||
|
|
||||
|
// compute base^100100100100100100100100100100 |
||||
|
POW_ACC_DEST(pg, t6, 6, t5, t3); |
||||
|
|
||||
|
// compute base^1001001001001001001001001001000100100100100100100100100100100 |
||||
|
POW_ACC_DEST(pg, t7, 31, t6, t6); |
||||
|
|
||||
|
// compute base^1001001001001001001001001001000110110110110110110110110110110111 |
||||
|
SQUARE(pg, t7); |
||||
|
MULTIPLY(pg, t7, t6); |
||||
|
SQUARE(pg, t7); |
||||
|
SQUARE(pg, t7); |
||||
|
MULTIPLY(pg, t7, t1); |
||||
|
MULTIPLY(pg, t7, t2); |
||||
|
mul(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3); |
||||
|
} |
||||
|
|
||||
|
#endif //RPO_SVE_RPO_HASH_H |
@ -0,0 +1,50 @@ |
|||||
|
fn main() {
|
||||
|
#[cfg(feature = "std")]
|
||||
|
compile_rpo_falcon();
|
||||
|
|
||||
|
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
|
compile_arch_arm64_sve();
|
||||
|
}
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
fn compile_rpo_falcon() {
|
||||
|
use std::path::PathBuf;
|
||||
|
|
||||
|
const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";
|
||||
|
|
||||
|
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h");
|
||||
|
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c");
|
||||
|
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h");
|
||||
|
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c");
|
||||
|
|
||||
|
let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
|
||||
|
let common_dir: PathBuf = ["PQClean", "common"].iter().collect();
|
||||
|
|
||||
|
let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
|
||||
|
let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();
|
||||
|
|
||||
|
cc::Build::new()
|
||||
|
.include(&common_dir)
|
||||
|
.include(target_dir)
|
||||
|
.files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
|
||||
|
.files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
|
||||
|
.file(format!("{RPO_FALCON_PATH}/falcon.c"))
|
||||
|
.file(format!("{RPO_FALCON_PATH}/rpo.c"))
|
||||
|
.flag("-O3")
|
||||
|
.compile("rpo_falcon512");
|
||||
|
}
|
||||
|
|
||||
|
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
|
fn compile_arch_arm64_sve() {
|
||||
|
const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";
|
||||
|
|
||||
|
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.c");
|
||||
|
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.h");
|
||||
|
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/rpo_hash.h");
|
||||
|
|
||||
|
cc::Build::new()
|
||||
|
.file(format!("{RPO_SVE_PATH}/library.c"))
|
||||
|
.flag("-march=armv8-a+sve")
|
||||
|
.flag("-O3")
|
||||
|
.compile("rpo_sve");
|
||||
|
}
|
@ -0,0 +1 @@ |
|||||
|
pub mod rpo_falcon512;
|
@ -0,0 +1,55 @@ |
|||||
|
use super::{LOG_N, MODULUS, PK_LEN};
|
||||
|
use core::fmt;
|
||||
|
|
||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
|
pub enum FalconError {
|
||||
|
KeyGenerationFailed,
|
||||
|
PubKeyDecodingExtraData,
|
||||
|
PubKeyDecodingInvalidCoefficient(u32),
|
||||
|
PubKeyDecodingInvalidLength(usize),
|
||||
|
PubKeyDecodingInvalidTag(u8),
|
||||
|
SigDecodingTooBigHighBits(u32),
|
||||
|
SigDecodingInvalidRemainder,
|
||||
|
SigDecodingNonZeroUnusedBitsLastByte,
|
||||
|
SigDecodingMinusZero,
|
||||
|
SigDecodingIncorrectEncodingAlgorithm,
|
||||
|
SigDecodingNotSupportedDegree(u8),
|
||||
|
SigGenerationFailed,
|
||||
|
}
|
||||
|
|
||||
|
impl fmt::Display for FalconError {
|
||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
|
use FalconError::*;
|
||||
|
match self {
|
||||
|
KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
|
||||
|
PubKeyDecodingExtraData => {
|
||||
|
write!(f, "Failed to decode public key: input not fully consumed")
|
||||
|
}
|
||||
|
PubKeyDecodingInvalidCoefficient(val) => {
|
||||
|
write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
|
||||
|
}
|
||||
|
PubKeyDecodingInvalidLength(len) => {
|
||||
|
write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
|
||||
|
}
|
||||
|
PubKeyDecodingInvalidTag(byte) => {
|
||||
|
write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
|
||||
|
}
|
||||
|
SigDecodingTooBigHighBits(m) => {
|
||||
|
write!(f, "Failed to decode signature: high bits {m} exceed 2048")
|
||||
|
}
|
||||
|
SigDecodingInvalidRemainder => {
|
||||
|
write!(f, "Failed to decode signature: incorrect remaining data")
|
||||
|
}
|
||||
|
SigDecodingNonZeroUnusedBitsLastByte => {
|
||||
|
write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
|
||||
|
}
|
||||
|
SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
|
||||
|
SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
|
||||
|
SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
|
||||
|
SigGenerationFailed => write!(f, "Failed to generate a signature"),
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
impl std::error::Error for FalconError {}
|
@ -0,0 +1,402 @@ |
|||||
|
/* |
||||
|
* Wrapper for implementing the PQClean API. |
||||
|
*/ |
||||
|
|
||||
|
#include <string.h> |
||||
|
#include "randombytes.h" |
||||
|
#include "falcon.h" |
||||
|
#include "inner.h" |
||||
|
#include "rpo.h" |
||||
|
|
||||
|
#define NONCELEN 40 |
||||
|
|
||||
|
/* |
||||
|
* Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024) |
||||
|
* |
||||
|
* private key: |
||||
|
* header byte: 0101nnnn |
||||
|
* private f (6 or 5 bits by element, depending on degree) |
||||
|
* private g (6 or 5 bits by element, depending on degree) |
||||
|
* private F (8 bits by element) |
||||
|
* |
||||
|
* public key: |
||||
|
* header byte: 0000nnnn |
||||
|
* public h (14 bits by element) |
||||
|
* |
||||
|
* signature: |
||||
|
* header byte: 0011nnnn |
||||
|
* nonce 40 bytes |
||||
|
* value (12 bits by element) |
||||
|
* |
||||
|
* message + signature: |
||||
|
* signature length (2 bytes, big-endian) |
||||
|
* nonce 40 bytes |
||||
|
* message |
||||
|
* header byte: 0010nnnn |
||||
|
* value (12 bits by element) |
||||
|
* (signature length is 1+len(value), not counting the nonce) |
||||
|
*/ |
||||
|
|
||||
|
/* see falcon.h */ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( |
||||
|
uint8_t *pk, |
||||
|
uint8_t *sk, |
||||
|
unsigned char *seed |
||||
|
) { |
||||
|
union |
||||
|
{ |
||||
|
uint8_t b[FALCON_KEYGEN_TEMP_9]; |
||||
|
uint64_t dummy_u64; |
||||
|
fpr dummy_fpr; |
||||
|
} tmp; |
||||
|
int8_t f[512], g[512], F[512]; |
||||
|
uint16_t h[512]; |
||||
|
inner_shake256_context rng; |
||||
|
size_t u, v; |
||||
|
|
||||
|
/* |
||||
|
* Generate key pair. |
||||
|
*/ |
||||
|
inner_shake256_init(&rng); |
||||
|
inner_shake256_inject(&rng, seed, sizeof seed); |
||||
|
inner_shake256_flip(&rng); |
||||
|
PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b); |
||||
|
inner_shake256_ctx_release(&rng); |
||||
|
|
||||
|
/* |
||||
|
* Encode private key. |
||||
|
*/ |
||||
|
sk[0] = 0x50 + 9; |
||||
|
u = 1; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
||||
|
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
||||
|
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
||||
|
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
|
||||
|
/* |
||||
|
* Encode public key. |
||||
|
*/ |
||||
|
pk[0] = 0x00 + 9; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_modq_encode( |
||||
|
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1, |
||||
|
h, 9); |
||||
|
if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
|
||||
|
return 0; |
||||
|
} |
||||
|
|
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( |
||||
|
uint8_t *pk, |
||||
|
uint8_t *sk |
||||
|
) { |
||||
|
unsigned char seed[48]; |
||||
|
|
||||
|
/* |
||||
|
* Generate a random seed. |
||||
|
*/ |
||||
|
randombytes(seed, sizeof seed); |
||||
|
|
||||
|
return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed); |
||||
|
} |
||||
|
|
||||
|
/* |
||||
|
* Compute the signature. nonce[] receives the nonce and must have length |
||||
|
* NONCELEN bytes. sigbuf[] receives the signature value (without nonce |
||||
|
* or header byte), with *sigbuflen providing the maximum value length and |
||||
|
* receiving the actual value length. |
||||
|
* |
||||
|
* If a signature could be computed but not encoded because it would |
||||
|
* exceed the output buffer size, then a new signature is computed. If |
||||
|
* the provided buffer size is too low, this could loop indefinitely, so |
||||
|
* the caller must provide a size that can accommodate signatures with a |
||||
|
* large enough probability. |
||||
|
* |
||||
|
* Return value: 0 on success, -1 on error. |
||||
|
*/ |
||||
|
static int do_sign( |
||||
|
uint8_t *nonce, |
||||
|
uint8_t *sigbuf, |
||||
|
size_t *sigbuflen, |
||||
|
const uint8_t *m, |
||||
|
size_t mlen, |
||||
|
const uint8_t *sk |
||||
|
) { |
||||
|
union |
||||
|
{ |
||||
|
uint8_t b[72 * 512]; |
||||
|
uint64_t dummy_u64; |
||||
|
fpr dummy_fpr; |
||||
|
} tmp; |
||||
|
int8_t f[512], g[512], F[512], G[512]; |
||||
|
struct |
||||
|
{ |
||||
|
int16_t sig[512]; |
||||
|
uint16_t hm[512]; |
||||
|
} r; |
||||
|
unsigned char seed[48]; |
||||
|
inner_shake256_context sc; |
||||
|
rpo128_context rc; |
||||
|
size_t u, v; |
||||
|
|
||||
|
/* |
||||
|
* Decode the private key. |
||||
|
*/ |
||||
|
if (sk[0] != 0x50 + 9) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u = 1; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
||||
|
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
||||
|
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
||||
|
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9], |
||||
|
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
||||
|
if (v == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
u += v; |
||||
|
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b)) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
|
||||
|
/* |
||||
|
* Create a random nonce (40 bytes). |
||||
|
*/ |
||||
|
randombytes(nonce, NONCELEN); |
||||
|
|
||||
|
/* ==== Start: Deviation from the reference implementation ================================= */ |
||||
|
|
||||
|
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that |
||||
|
// the conversion to field elements succeeds |
||||
|
uint8_t buffer[64]; |
||||
|
memset(buffer, 0, 64); |
||||
|
for (size_t i = 0; i < 8; i++) |
||||
|
{ |
||||
|
buffer[8 * i] = nonce[5 * i]; |
||||
|
buffer[8 * i + 1] = nonce[5 * i + 1]; |
||||
|
buffer[8 * i + 2] = nonce[5 * i + 2]; |
||||
|
buffer[8 * i + 3] = nonce[5 * i + 3]; |
||||
|
buffer[8 * i + 4] = nonce[5 * i + 4]; |
||||
|
} |
||||
|
|
||||
|
/* |
||||
|
* Hash message nonce + message into a vector. |
||||
|
*/ |
||||
|
rpo128_init(&rc); |
||||
|
rpo128_absorb(&rc, buffer, NONCELEN + 24); |
||||
|
rpo128_absorb(&rc, m, mlen); |
||||
|
rpo128_finalize(&rc); |
||||
|
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9); |
||||
|
rpo128_release(&rc); |
||||
|
|
||||
|
/* ==== End: Deviation from the reference implementation =================================== */ |
||||
|
|
||||
|
/* |
||||
|
* Initialize a RNG. |
||||
|
*/ |
||||
|
randombytes(seed, sizeof seed); |
||||
|
inner_shake256_init(&sc); |
||||
|
inner_shake256_inject(&sc, seed, sizeof seed); |
||||
|
inner_shake256_flip(&sc); |
||||
|
|
||||
|
/* |
||||
|
* Compute and return the signature. This loops until a signature |
||||
|
* value is found that fits in the provided buffer. |
||||
|
*/ |
||||
|
for (;;) |
||||
|
{ |
||||
|
PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b); |
||||
|
v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9); |
||||
|
if (v != 0) |
||||
|
{ |
||||
|
inner_shake256_ctx_release(&sc); |
||||
|
*sigbuflen = v; |
||||
|
return 0; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/* |
||||
|
* Verify a signature. The nonce has size NONCELEN bytes. sigbuf[] |
||||
|
* (of size sigbuflen) contains the signature value, not including the |
||||
|
* header byte or nonce. Return value is 0 on success, -1 on error. |
||||
|
*/ |
||||
|
static int do_verify( |
||||
|
const uint8_t *nonce, |
||||
|
const uint8_t *sigbuf, |
||||
|
size_t sigbuflen, |
||||
|
const uint8_t *m, |
||||
|
size_t mlen, |
||||
|
const uint8_t *pk |
||||
|
) { |
||||
|
union |
||||
|
{ |
||||
|
uint8_t b[2 * 512]; |
||||
|
uint64_t dummy_u64; |
||||
|
fpr dummy_fpr; |
||||
|
} tmp; |
||||
|
uint16_t h[512], hm[512]; |
||||
|
int16_t sig[512]; |
||||
|
rpo128_context rc; |
||||
|
|
||||
|
/* |
||||
|
* Decode public key. |
||||
|
*/ |
||||
|
if (pk[0] != 0x00 + 9) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9, |
||||
|
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
||||
|
!= PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9); |
||||
|
|
||||
|
/* |
||||
|
* Decode signature. |
||||
|
*/ |
||||
|
if (sigbuflen == 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
|
||||
|
/* ==== Start: Deviation from the reference implementation ================================= */ |
||||
|
|
||||
|
/* |
||||
|
* Hash nonce + message into a vector. |
||||
|
*/ |
||||
|
|
||||
|
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that |
||||
|
// the conversion to field elements succeeds |
||||
|
uint8_t buffer[64]; |
||||
|
memset(buffer, 0, 64); |
||||
|
for (size_t i = 0; i < 8; i++) |
||||
|
{ |
||||
|
buffer[8 * i] = nonce[5 * i]; |
||||
|
buffer[8 * i + 1] = nonce[5 * i + 1]; |
||||
|
buffer[8 * i + 2] = nonce[5 * i + 2]; |
||||
|
buffer[8 * i + 3] = nonce[5 * i + 3]; |
||||
|
buffer[8 * i + 4] = nonce[5 * i + 4]; |
||||
|
} |
||||
|
|
||||
|
rpo128_init(&rc); |
||||
|
rpo128_absorb(&rc, buffer, NONCELEN + 24); |
||||
|
rpo128_absorb(&rc, m, mlen); |
||||
|
rpo128_finalize(&rc); |
||||
|
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9); |
||||
|
rpo128_release(&rc); |
||||
|
|
||||
|
/* === End: Deviation from the reference implementation ==================================== */ |
||||
|
|
||||
|
/* |
||||
|
* Verify signature. |
||||
|
*/ |
||||
|
if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b)) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
return 0; |
||||
|
} |
||||
|
|
||||
|
/* see falcon.h */ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( |
||||
|
uint8_t *sig, |
||||
|
size_t *siglen, |
||||
|
const uint8_t *m, |
||||
|
size_t mlen, |
||||
|
const uint8_t *sk |
||||
|
) { |
||||
|
/* |
||||
|
* The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for |
||||
|
* the signed message object (as produced by crypto_sign()) |
||||
|
* and includes a two-byte length value, so we take care here |
||||
|
* to only generate signatures that are two bytes shorter than |
||||
|
* the maximum. This is done to ensure that crypto_sign() |
||||
|
* and crypto_sign_signature() produce the exact same signature |
||||
|
* value, if used on the same message, with the same private key, |
||||
|
* and using the same output from randombytes() (this is for |
||||
|
* reproducibility of tests). |
||||
|
*/ |
||||
|
size_t vlen; |
||||
|
|
||||
|
vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3; |
||||
|
if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
sig[0] = 0x30 + 9; |
||||
|
*siglen = 1 + NONCELEN + vlen; |
||||
|
return 0; |
||||
|
} |
||||
|
|
||||
|
/* see falcon.h */ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( |
||||
|
const uint8_t *sig, |
||||
|
size_t siglen, |
||||
|
const uint8_t *m, |
||||
|
size_t mlen, |
||||
|
const uint8_t *pk |
||||
|
) { |
||||
|
if (siglen < 1 + NONCELEN) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
if (sig[0] != 0x30 + 9) |
||||
|
{ |
||||
|
return -1; |
||||
|
} |
||||
|
return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk); |
||||
|
} |
@ -0,0 +1,66 @@ |
|||||
|
#include <stddef.h> |
||||
|
#include <stdint.h> |
||||
|
|
||||
|
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281 |
||||
|
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897 |
||||
|
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666 |
||||
|
|
||||
|
/* |
||||
|
* Generate a new key pair. Public key goes into pk[], private key in sk[]. |
||||
|
* Key sizes are exact (in bytes): |
||||
|
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES |
||||
|
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES |
||||
|
* |
||||
|
* Return value: 0 on success, -1 on error. |
||||
|
* |
||||
|
* Note: This implementation follows the reference implementation in PQClean |
||||
|
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
||||
|
* verbatim except for the sections that are marked otherwise. |
||||
|
*/ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( |
||||
|
uint8_t *pk, uint8_t *sk); |
||||
|
|
||||
|
/* |
||||
|
* Generate a new key pair from seed. Public key goes into pk[], private key in sk[]. |
||||
|
* Key sizes are exact (in bytes): |
||||
|
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES |
||||
|
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES |
||||
|
* |
||||
|
* Return value: 0 on success, -1 on error. |
||||
|
*/ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( |
||||
|
uint8_t *pk, uint8_t *sk, unsigned char *seed); |
||||
|
|
||||
|
/* |
||||
|
* Compute a signature on a provided message (m, mlen), with a given |
||||
|
* private key (sk). Signature is written in sig[], with length written |
||||
|
* into *siglen. Signature length is variable; maximum signature length |
||||
|
* (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES. |
||||
|
* |
||||
|
* sig[], m[] and sk[] may overlap each other arbitrarily. |
||||
|
* |
||||
|
* Return value: 0 on success, -1 on error. |
||||
|
* |
||||
|
* Note: This implementation follows the reference implementation in PQClean |
||||
|
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
||||
|
* verbatim except for the sections that are marked otherwise. |
||||
|
*/ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( |
||||
|
uint8_t *sig, size_t *siglen, |
||||
|
const uint8_t *m, size_t mlen, const uint8_t *sk); |
||||
|
|
||||
|
/* |
||||
|
* Verify a signature (sig, siglen) on a message (m, mlen) with a given |
||||
|
* public key (pk). |
||||
|
* |
||||
|
* sig[], m[] and pk[] may overlap each other arbitrarily. |
||||
|
* |
||||
|
* Return value: 0 on success, -1 on error. |
||||
|
* |
||||
|
* Note: This implementation follows the reference implementation in PQClean |
||||
|
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
||||
|
* verbatim except for the sections that are marked otherwise. |
||||
|
*/ |
||||
|
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( |
||||
|
const uint8_t *sig, size_t siglen, |
||||
|
const uint8_t *m, size_t mlen, const uint8_t *pk); |
@ -0,0 +1,582 @@ |
|||||
|
/* |
||||
|
* RPO implementation. |
||||
|
*/ |
||||
|
|
||||
|
#include <stdint.h> |
||||
|
#include <string.h> |
||||
|
#include <stdlib.h> |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* Modular Arithmetic |
||||
|
*/ |
||||
|
|
||||
|
#define P 0xFFFFFFFF00000001 |
||||
|
#define M 12289 |
||||
|
|
||||
|
// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go |
||||
|
static uint64_t add_mod_p(uint64_t a, uint64_t b) |
||||
|
{ |
||||
|
a = P - a; |
||||
|
uint64_t res = b - a; |
||||
|
if (b < a) |
||||
|
res += P; |
||||
|
return res; |
||||
|
} |
||||
|
|
||||
|
static uint64_t sub_mod_p(uint64_t a, uint64_t b) |
||||
|
{ |
||||
|
uint64_t r = a - b; |
||||
|
if (a < b) |
||||
|
r += P; |
||||
|
return r; |
||||
|
} |
||||
|
|
||||
|
static uint64_t reduce_mod_p(uint64_t b, uint64_t a) |
||||
|
{ |
||||
|
uint32_t d = b >> 32, |
||||
|
c = b; |
||||
|
if (a >= P) |
||||
|
a -= P; |
||||
|
a = sub_mod_p(a, c); |
||||
|
a = sub_mod_p(a, d); |
||||
|
a = add_mod_p(a, ((uint64_t)c) << 32); |
||||
|
return a; |
||||
|
} |
||||
|
|
||||
|
static uint64_t mult_mod_p(uint64_t x, uint64_t y) |
||||
|
{ |
||||
|
uint32_t a = x, |
||||
|
b = x >> 32, |
||||
|
c = y, |
||||
|
d = y >> 32; |
||||
|
|
||||
|
/* first synthesize the product using 32*32 -> 64 bit multiplies */ |
||||
|
x = b * (uint64_t)c; /* b*c */ |
||||
|
y = a * (uint64_t)d; /* a*d */ |
||||
|
uint64_t e = a * (uint64_t)c, /* a*c */ |
||||
|
f = b * (uint64_t)d, /* b*d */ |
||||
|
t; |
||||
|
|
||||
|
x += y; /* b*c + a*d */ |
||||
|
/* carry? */ |
||||
|
if (x < y) |
||||
|
f += 1LL << 32; /* carry into upper 32 bits - can't overflow */ |
||||
|
|
||||
|
t = x << 32; |
||||
|
e += t; /* a*c + LSW(b*c + a*d) */ |
||||
|
/* carry? */ |
||||
|
if (e < t) |
||||
|
f += 1; /* carry into upper 64 bits - can't overflow*/ |
||||
|
t = x >> 32; |
||||
|
f += t; /* b*d + MSW(b*c + a*d) */ |
||||
|
/* can't overflow */ |
||||
|
|
||||
|
/* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */ |
||||
|
return reduce_mod_p(f, e); |
||||
|
} |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* RPO128 Permutation |
||||
|
*/ |
||||
|
|
||||
|
static const uint64_t STATE_WIDTH = 12; |
||||
|
static const uint64_t NUM_ROUNDS = 7; |
||||
|
|
||||
|
/* |
||||
|
* MDS matrix |
||||
|
*/ |
||||
|
static const uint64_t MDS[12][12] = { |
||||
|
{ 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 }, |
||||
|
{ 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 }, |
||||
|
{ 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 }, |
||||
|
{ 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6 }, |
||||
|
{ 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7 }, |
||||
|
{ 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9 }, |
||||
|
{ 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10 }, |
||||
|
{ 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13 }, |
||||
|
{ 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26 }, |
||||
|
{ 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8 }, |
||||
|
{ 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23 }, |
||||
|
{ 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7 }, |
||||
|
}; |
||||
|
|
||||
|
/* |
||||
|
* Round constants. |
||||
|
*/ |
||||
|
static const uint64_t ARK1[7][12] = { |
||||
|
{ |
||||
|
5789762306288267392ULL, |
||||
|
6522564764413701783ULL, |
||||
|
17809893479458208203ULL, |
||||
|
107145243989736508ULL, |
||||
|
6388978042437517382ULL, |
||||
|
15844067734406016715ULL, |
||||
|
9975000513555218239ULL, |
||||
|
3344984123768313364ULL, |
||||
|
9959189626657347191ULL, |
||||
|
12960773468763563665ULL, |
||||
|
9602914297752488475ULL, |
||||
|
16657542370200465908ULL, |
||||
|
}, |
||||
|
{ |
||||
|
12987190162843096997ULL, |
||||
|
653957632802705281ULL, |
||||
|
4441654670647621225ULL, |
||||
|
4038207883745915761ULL, |
||||
|
5613464648874830118ULL, |
||||
|
13222989726778338773ULL, |
||||
|
3037761201230264149ULL, |
||||
|
16683759727265180203ULL, |
||||
|
8337364536491240715ULL, |
||||
|
3227397518293416448ULL, |
||||
|
8110510111539674682ULL, |
||||
|
2872078294163232137ULL, |
||||
|
}, |
||||
|
{ |
||||
|
18072785500942327487ULL, |
||||
|
6200974112677013481ULL, |
||||
|
17682092219085884187ULL, |
||||
|
10599526828986756440ULL, |
||||
|
975003873302957338ULL, |
||||
|
8264241093196931281ULL, |
||||
|
10065763900435475170ULL, |
||||
|
2181131744534710197ULL, |
||||
|
6317303992309418647ULL, |
||||
|
1401440938888741532ULL, |
||||
|
8884468225181997494ULL, |
||||
|
13066900325715521532ULL, |
||||
|
}, |
||||
|
{ |
||||
|
5674685213610121970ULL, |
||||
|
5759084860419474071ULL, |
||||
|
13943282657648897737ULL, |
||||
|
1352748651966375394ULL, |
||||
|
17110913224029905221ULL, |
||||
|
1003883795902368422ULL, |
||||
|
4141870621881018291ULL, |
||||
|
8121410972417424656ULL, |
||||
|
14300518605864919529ULL, |
||||
|
13712227150607670181ULL, |
||||
|
17021852944633065291ULL, |
||||
|
6252096473787587650ULL, |
||||
|
}, |
||||
|
{ |
||||
|
4887609836208846458ULL, |
||||
|
3027115137917284492ULL, |
||||
|
9595098600469470675ULL, |
||||
|
10528569829048484079ULL, |
||||
|
7864689113198939815ULL, |
||||
|
17533723827845969040ULL, |
||||
|
5781638039037710951ULL, |
||||
|
17024078752430719006ULL, |
||||
|
109659393484013511ULL, |
||||
|
7158933660534805869ULL, |
||||
|
2955076958026921730ULL, |
||||
|
7433723648458773977ULL, |
||||
|
}, |
||||
|
{ |
||||
|
16308865189192447297ULL, |
||||
|
11977192855656444890ULL, |
||||
|
12532242556065780287ULL, |
||||
|
14594890931430968898ULL, |
||||
|
7291784239689209784ULL, |
||||
|
5514718540551361949ULL, |
||||
|
10025733853830934803ULL, |
||||
|
7293794580341021693ULL, |
||||
|
6728552937464861756ULL, |
||||
|
6332385040983343262ULL, |
||||
|
13277683694236792804ULL, |
||||
|
2600778905124452676ULL, |
||||
|
}, |
||||
|
{ |
||||
|
7123075680859040534ULL, |
||||
|
1034205548717903090ULL, |
||||
|
7717824418247931797ULL, |
||||
|
3019070937878604058ULL, |
||||
|
11403792746066867460ULL, |
||||
|
10280580802233112374ULL, |
||||
|
337153209462421218ULL, |
||||
|
13333398568519923717ULL, |
||||
|
3596153696935337464ULL, |
||||
|
8104208463525993784ULL, |
||||
|
14345062289456085693ULL, |
||||
|
17036731477169661256ULL, |
||||
|
}}; |
||||
|
|
||||
|
const uint64_t ARK2[7][12] = { |
||||
|
{ |
||||
|
6077062762357204287ULL, |
||||
|
15277620170502011191ULL, |
||||
|
5358738125714196705ULL, |
||||
|
14233283787297595718ULL, |
||||
|
13792579614346651365ULL, |
||||
|
11614812331536767105ULL, |
||||
|
14871063686742261166ULL, |
||||
|
10148237148793043499ULL, |
||||
|
4457428952329675767ULL, |
||||
|
15590786458219172475ULL, |
||||
|
10063319113072092615ULL, |
||||
|
14200078843431360086ULL, |
||||
|
}, |
||||
|
{ |
||||
|
6202948458916099932ULL, |
||||
|
17690140365333231091ULL, |
||||
|
3595001575307484651ULL, |
||||
|
373995945117666487ULL, |
||||
|
1235734395091296013ULL, |
||||
|
14172757457833931602ULL, |
||||
|
707573103686350224ULL, |
||||
|
15453217512188187135ULL, |
||||
|
219777875004506018ULL, |
||||
|
17876696346199469008ULL, |
||||
|
17731621626449383378ULL, |
||||
|
2897136237748376248ULL, |
||||
|
}, |
||||
|
{ |
||||
|
8023374565629191455ULL, |
||||
|
15013690343205953430ULL, |
||||
|
4485500052507912973ULL, |
||||
|
12489737547229155153ULL, |
||||
|
9500452585969030576ULL, |
||||
|
2054001340201038870ULL, |
||||
|
12420704059284934186ULL, |
||||
|
355990932618543755ULL, |
||||
|
9071225051243523860ULL, |
||||
|
12766199826003448536ULL, |
||||
|
9045979173463556963ULL, |
||||
|
12934431667190679898ULL, |
||||
|
}, |
||||
|
{ |
||||
|
18389244934624494276ULL, |
||||
|
16731736864863925227ULL, |
||||
|
4440209734760478192ULL, |
||||
|
17208448209698888938ULL, |
||||
|
8739495587021565984ULL, |
||||
|
17000774922218161967ULL, |
||||
|
13533282547195532087ULL, |
||||
|
525402848358706231ULL, |
||||
|
16987541523062161972ULL, |
||||
|
5466806524462797102ULL, |
||||
|
14512769585918244983ULL, |
||||
|
10973956031244051118ULL, |
||||
|
}, |
||||
|
{ |
||||
|
6982293561042362913ULL, |
||||
|
14065426295947720331ULL, |
||||
|
16451845770444974180ULL, |
||||
|
7139138592091306727ULL, |
||||
|
9012006439959783127ULL, |
||||
|
14619614108529063361ULL, |
||||
|
1394813199588124371ULL, |
||||
|
4635111139507788575ULL, |
||||
|
16217473952264203365ULL, |
||||
|
10782018226466330683ULL, |
||||
|
6844229992533662050ULL, |
||||
|
7446486531695178711ULL, |
||||
|
}, |
||||
|
{ |
||||
|
3736792340494631448ULL, |
||||
|
577852220195055341ULL, |
||||
|
6689998335515779805ULL, |
||||
|
13886063479078013492ULL, |
||||
|
14358505101923202168ULL, |
||||
|
7744142531772274164ULL, |
||||
|
16135070735728404443ULL, |
||||
|
12290902521256031137ULL, |
||||
|
12059913662657709804ULL, |
||||
|
16456018495793751911ULL, |
||||
|
4571485474751953524ULL, |
||||
|
17200392109565783176ULL, |
||||
|
}, |
||||
|
{ |
||||
|
17130398059294018733ULL, |
||||
|
519782857322261988ULL, |
||||
|
9625384390925085478ULL, |
||||
|
1664893052631119222ULL, |
||||
|
7629576092524553570ULL, |
||||
|
3485239601103661425ULL, |
||||
|
9755891797164033838ULL, |
||||
|
15218148195153269027ULL, |
||||
|
16460604813734957368ULL, |
||||
|
9643968136937729763ULL, |
||||
|
3611348709641382851ULL, |
||||
|
18256379591337759196ULL, |
||||
|
}, |
||||
|
}; |
||||
|
|
||||
|
static void apply_sbox(uint64_t *const state) |
||||
|
{ |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
uint64_t t2 = mult_mod_p(*(state + i), *(state + i)); |
||||
|
uint64_t t4 = mult_mod_p(t2, t2); |
||||
|
|
||||
|
*(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4)); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static void apply_mds(uint64_t *state) |
||||
|
{ |
||||
|
uint64_t res[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
res[i] = 0; |
||||
|
} |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
for (uint64_t j = 0; j < STATE_WIDTH; j++) |
||||
|
{ |
||||
|
res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
*(state + i) = res[i]; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static void apply_constants(uint64_t *const state, const uint64_t *ark) |
||||
|
{ |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
*(state + i) = add_mod_p(*(state + i), *(ark + i)); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res) |
||||
|
{ |
||||
|
for (uint64_t i = 0; i < m; i++) |
||||
|
{ |
||||
|
for (uint64_t j = 0; j < STATE_WIDTH; j++) |
||||
|
{ |
||||
|
if (i == 0) |
||||
|
{ |
||||
|
*(res + j) = mult_mod_p(*(base + j), *(base + j)); |
||||
|
} |
||||
|
else |
||||
|
{ |
||||
|
*(res + j) = mult_mod_p(*(res + j), *(res + j)); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
*(res + i) = mult_mod_p(*(res + i), *(tail + i)); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static void apply_inv_sbox(uint64_t *const state) |
||||
|
{ |
||||
|
uint64_t t1[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t1[i] = 0; |
||||
|
} |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t1[i] = mult_mod_p(*(state + i), *(state + i)); |
||||
|
} |
||||
|
|
||||
|
uint64_t t2[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t2[i] = 0; |
||||
|
} |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t2[i] = mult_mod_p(t1[i], t1[i]); |
||||
|
} |
||||
|
|
||||
|
uint64_t t3[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t3[i] = 0; |
||||
|
} |
||||
|
exp_acc(3, t2, t2, t3); |
||||
|
|
||||
|
uint64_t t4[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t4[i] = 0; |
||||
|
} |
||||
|
exp_acc(6, t3, t3, t4); |
||||
|
|
||||
|
uint64_t tmp[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
tmp[i] = 0; |
||||
|
} |
||||
|
exp_acc(12, t4, t4, tmp); |
||||
|
|
||||
|
uint64_t t5[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t5[i] = 0; |
||||
|
} |
||||
|
exp_acc(6, tmp, t3, t5); |
||||
|
|
||||
|
uint64_t t6[STATE_WIDTH]; |
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
t6[i] = 0; |
||||
|
} |
||||
|
exp_acc(31, t5, t5, t6); |
||||
|
|
||||
|
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
||||
|
{ |
||||
|
uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]); |
||||
|
a = mult_mod_p(a, a); |
||||
|
a = mult_mod_p(a, a); |
||||
|
uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i)); |
||||
|
|
||||
|
*(state + i) = mult_mod_p(a, b); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static void apply_round(uint64_t *const state, const uint64_t round) |
||||
|
{ |
||||
|
apply_mds(state); |
||||
|
apply_constants(state, ARK1[round]); |
||||
|
apply_sbox(state); |
||||
|
|
||||
|
apply_mds(state); |
||||
|
apply_constants(state, ARK2[round]); |
||||
|
apply_inv_sbox(state); |
||||
|
} |
||||
|
|
||||
|
static void apply_permutation(uint64_t *state) |
||||
|
{ |
||||
|
for (uint64_t i = 0; i < NUM_ROUNDS; i++) |
||||
|
{ |
||||
|
apply_round(state, i); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* RPO128 implementation. This is supposed to substitute SHAKE256 in the hash-to-point algorithm. |
||||
|
*/ |
||||
|
|
||||
|
#include "rpo.h" |
||||
|
|
||||
|
void rpo128_init(rpo128_context *rc) |
||||
|
{ |
||||
|
rc->dptr = 32; |
||||
|
|
||||
|
memset(rc->st.A, 0, sizeof rc->st.A); |
||||
|
} |
||||
|
|
||||
|
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len) |
||||
|
{ |
||||
|
size_t dptr; |
||||
|
|
||||
|
dptr = (size_t)rc->dptr; |
||||
|
while (len > 0) |
||||
|
{ |
||||
|
size_t clen, u; |
||||
|
|
||||
|
/* 136 * 8 = 1088 bit for the rate portion in the case of SHAKE256 |
||||
|
* For RPO, this is 64 * 8 = 512 bits |
||||
|
* The capacity for SHAKE256 is at the end while for RPO128 it is at the beginning |
||||
|
*/ |
||||
|
clen = 96 - dptr; |
||||
|
if (clen > len) |
||||
|
{ |
||||
|
clen = len; |
||||
|
} |
||||
|
|
||||
|
for (u = 0; u < clen; u++) |
||||
|
{ |
||||
|
rc->st.dbuf[dptr + u] = in[u]; |
||||
|
} |
||||
|
|
||||
|
dptr += clen; |
||||
|
in += clen; |
||||
|
len -= clen; |
||||
|
if (dptr == 96) |
||||
|
{ |
||||
|
apply_permutation(rc->st.A); |
||||
|
dptr = 32; |
||||
|
} |
||||
|
} |
||||
|
rc->dptr = dptr; |
||||
|
} |
||||
|
|
||||
|
void rpo128_finalize(rpo128_context *rc) |
||||
|
{ |
||||
|
// Set dptr to the end of the buffer, so that first call to extract will call the permutation. |
||||
|
rc->dptr = 96; |
||||
|
} |
||||
|
|
||||
|
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len) |
||||
|
{ |
||||
|
size_t dptr; |
||||
|
|
||||
|
dptr = (size_t)rc->dptr; |
||||
|
while (len > 0) |
||||
|
{ |
||||
|
size_t clen; |
||||
|
|
||||
|
if (dptr == 96) |
||||
|
{ |
||||
|
apply_permutation(rc->st.A); |
||||
|
dptr = 32; |
||||
|
} |
||||
|
clen = 96 - dptr; |
||||
|
if (clen > len) |
||||
|
{ |
||||
|
clen = len; |
||||
|
} |
||||
|
len -= clen; |
||||
|
|
||||
|
memcpy(out, rc->st.dbuf + dptr, clen); |
||||
|
dptr += clen; |
||||
|
out += clen; |
||||
|
} |
||||
|
rc->dptr = dptr; |
||||
|
} |
||||
|
|
||||
|
void rpo128_release(rpo128_context *rc) |
||||
|
{ |
||||
|
memset(rc->st.A, 0, sizeof rc->st.A); |
||||
|
rc->dptr = 32; |
||||
|
} |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* Hash-to-Point algorithm implementation based on RPO128 |
||||
|
*/ |
||||
|
|
||||
|
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn) |
||||
|
{ |
||||
|
/* |
||||
|
* This implementation avoids the rejection sampling step needed in the |
||||
|
* per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf |
||||
|
* page 31, which argues that the current variant is secure for the parameters set by NIST. |
||||
|
* Avoiding the rejection-sampling step leads to an implementation that is constant-time. |
||||
|
* TODO: Check that the current implementation is indeed constant-time. |
||||
|
*/ |
||||
|
size_t n; |
||||
|
|
||||
|
n = (size_t)1 << logn; |
||||
|
while (n > 0) |
||||
|
{ |
||||
|
uint8_t buf[8]; |
||||
|
uint64_t w; |
||||
|
|
||||
|
rpo128_squeeze(rc, (void *)buf, sizeof buf); |
||||
|
w = ((uint64_t)(buf[7]) << 56) | |
||||
|
((uint64_t)(buf[6]) << 48) | |
||||
|
((uint64_t)(buf[5]) << 40) | |
||||
|
((uint64_t)(buf[4]) << 32) | |
||||
|
((uint64_t)(buf[3]) << 24) | |
||||
|
((uint64_t)(buf[2]) << 16) | |
||||
|
((uint64_t)(buf[1]) << 8) | |
||||
|
((uint64_t)(buf[0])); |
||||
|
|
||||
|
w %= M; |
||||
|
|
||||
|
*x++ = (uint16_t)w; |
||||
|
n--; |
||||
|
} |
||||
|
} |
@ -0,0 +1,83 @@ |
|||||
|
#include <stdint.h> |
||||
|
#include <string.h> |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* RPO hashing algorithm related structs and methods. |
||||
|
*/ |
||||
|
|
||||
|
/* |
||||
|
* RPO128 context. |
||||
|
* |
||||
|
* This structure is used by the hashing API. It is composed of an internal state that can be |
||||
|
* viewed as either: |
||||
|
* 1. 12 field elements in the Miden VM. |
||||
|
* 2. 96 bytes. |
||||
|
* |
||||
|
* The first view is used for the internal state in the context of the RPO hashing algorithm. The |
||||
|
* second view is used for the buffer used to absorb the data to be hashed. |
||||
|
* |
||||
|
* The pointer to the buffer is updated as the data is absorbed. |
||||
|
* |
||||
|
* 'rpo128_context' must be initialized with rpo128_init() before first use. |
||||
|
*/ |
||||
|
typedef struct |
||||
|
{ |
||||
|
union |
||||
|
{ |
||||
|
uint64_t A[12]; |
||||
|
uint8_t dbuf[96]; |
||||
|
} st; |
||||
|
uint64_t dptr; |
||||
|
} rpo128_context; |
||||
|
|
||||
|
/* |
||||
|
* Initializes an RPO state |
||||
|
*/ |
||||
|
void rpo128_init(rpo128_context *rc); |
||||
|
|
||||
|
/* |
||||
|
* Absorbs an array of bytes of length 'len' into the state. |
||||
|
*/ |
||||
|
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len); |
||||
|
|
||||
|
/* |
||||
|
* Squeezes an array of bytes of length 'len' from the state. |
||||
|
*/ |
||||
|
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len); |
||||
|
|
||||
|
/* |
||||
|
* Finalizes the state in preparation for squeezing. |
||||
|
* |
||||
|
* This function should be called after all the data has been absorbed. |
||||
|
* |
||||
|
* Note that the current implementation does not perform any sort of padding for domain separation |
||||
|
* purposes. The reason being that, for our purposes, we always perform the following sequence: |
||||
|
* 1. Absorb a Nonce (which is always 40 bytes packed as 8 field elements). |
||||
|
* 2. Absorb the message (which is always 4 field elements). |
||||
|
* 3. Call finalize. |
||||
|
* 4. Squeeze the output. |
||||
|
* 5. Call release. |
||||
|
*/ |
||||
|
void rpo128_finalize(rpo128_context *rc); |
||||
|
|
||||
|
/* |
||||
|
* Releases the state. |
||||
|
* |
||||
|
* This function should be called after the squeeze operation is finished. |
||||
|
*/ |
||||
|
void rpo128_release(rpo128_context *rc); |
||||
|
|
||||
|
/* ================================================================================================ |
||||
|
* Hash-to-Point algorithm for signature generation and signature verification. |
||||
|
*/ |
||||
|
|
||||
|
/* |
||||
|
* Hash-to-Point algorithm. |
||||
|
* |
||||
|
* This function generates a point in Z_q[x]/(phi) from a given message. |
||||
|
* |
||||
|
* It takes a finalized rpo128_context as input and it generates the coefficients of the polynomial |
||||
|
* representing the point. The coefficients are stored in the array 'x'. The number of coefficients |
||||
|
* is given by 'logn', which must in our case is 512. |
||||
|
*/ |
||||
|
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn); |
@ -0,0 +1,189 @@ |
|||||
|
use libc::c_int;
|
||||
|
|
||||
|
// C IMPLEMENTATION INTERFACE
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[link(name = "rpo_falcon512", kind = "static")]
|
||||
|
extern "C" {
|
||||
|
/// Generate a new key pair. Public key goes into pk[], private key in sk[].
|
||||
|
/// Key sizes are exact (in bytes):
|
||||
|
/// - public (pk): 897
|
||||
|
/// - private (sk): 1281
|
||||
|
///
|
||||
|
/// Return value: 0 on success, -1 on error.
|
||||
|
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;
|
||||
|
|
||||
|
/// Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
|
||||
|
/// Key sizes are exact (in bytes):
|
||||
|
/// - public (pk): 897
|
||||
|
/// - private (sk): 1281
|
||||
|
///
|
||||
|
/// Return value: 0 on success, -1 on error.
|
||||
|
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
|
||||
|
pk: *mut u8,
|
||||
|
sk: *mut u8,
|
||||
|
seed: *const u8,
|
||||
|
) -> c_int;
|
||||
|
|
||||
|
/// Compute a signature on a provided message (m, mlen), with a given private key (sk).
|
||||
|
/// Signature is written in sig[], with length written into *siglen. Signature length is
|
||||
|
/// variable; maximum signature length (in bytes) is 666.
|
||||
|
///
|
||||
|
/// sig[], m[] and sk[] may overlap each other arbitrarily.
|
||||
|
///
|
||||
|
/// Return value: 0 on success, -1 on error.
|
||||
|
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
|
||||
|
sig: *mut u8,
|
||||
|
siglen: *mut usize,
|
||||
|
m: *const u8,
|
||||
|
mlen: usize,
|
||||
|
sk: *const u8,
|
||||
|
) -> c_int;
|
||||
|
|
||||
|
// TEST HELPERS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
|
||||
|
///
|
||||
|
/// sig[], m[] and pk[] may overlap each other arbitrarily.
|
||||
|
///
|
||||
|
/// Return value: 0 on success, -1 on error.
|
||||
|
#[cfg(test)]
|
||||
|
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
|
||||
|
sig: *const u8,
|
||||
|
siglen: usize,
|
||||
|
m: *const u8,
|
||||
|
mlen: usize,
|
||||
|
pk: *const u8,
|
||||
|
) -> c_int;
|
||||
|
|
||||
|
/// Hash-to-Point algorithm.
|
||||
|
///
|
||||
|
/// This function generates a point in Z_q[x]/(phi) from a given message.
|
||||
|
///
|
||||
|
/// It takes a finalized rpo128_context as input and it generates the coefficients of the polynomial
|
||||
|
/// representing the point. The coefficients are stored in the array 'x'. The number of coefficients
|
||||
|
/// is given by 'logn', which must in our case is 512.
|
||||
|
#[cfg(test)]
|
||||
|
pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
|
||||
|
rc: *mut Rpo128Context,
|
||||
|
x: *mut u16,
|
||||
|
logn: usize,
|
||||
|
);
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
pub fn rpo128_init(sc: *mut Rpo128Context);
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
pub fn rpo128_absorb(
|
||||
|
sc: *mut Rpo128Context,
|
||||
|
data: *const ::std::os::raw::c_void,
|
||||
|
len: libc::size_t,
|
||||
|
);
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
pub fn rpo128_finalize(sc: *mut Rpo128Context);
|
||||
|
}
|
||||
|
|
||||
|
#[repr(C)]
|
||||
|
#[cfg(test)]
|
||||
|
pub struct Rpo128Context {
|
||||
|
pub content: [u64; 13usize],
|
||||
|
}
|
||||
|
|
||||
|
// TESTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[cfg(all(test, feature = "std"))]
|
||||
|
mod tests {
|
||||
|
use super::*;
|
||||
|
use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
|
||||
|
use rand_utils::{rand_array, rand_value, rand_vector};
|
||||
|
|
||||
|
#[test]
|
||||
|
fn falcon_ffi() {
|
||||
|
unsafe {
|
||||
|
//let mut rng = rand::thread_rng();
|
||||
|
|
||||
|
// --- generate a key pair from a seed ----------------------------
|
||||
|
|
||||
|
let mut pk = [0u8; PK_LEN];
|
||||
|
let mut sk = [0u8; SK_LEN];
|
||||
|
let seed: [u8; NONCE_LEN] = rand_array();
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
0,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
|
||||
|
pk.as_mut_ptr(),
|
||||
|
sk.as_mut_ptr(),
|
||||
|
seed.as_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
|
||||
|
// --- sign a message and make sure it verifies -------------------
|
||||
|
|
||||
|
let mlen: usize = rand_value::<u16>() as usize;
|
||||
|
let msg: Vec<u8> = rand_vector(mlen);
|
||||
|
let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
|
||||
|
let mut siglen = 0;
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
0,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
|
||||
|
detached_sig.as_mut_ptr(),
|
||||
|
&mut siglen as *mut usize,
|
||||
|
msg.as_ptr(),
|
||||
|
msg.len(),
|
||||
|
sk.as_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
0,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
|
||||
|
detached_sig.as_ptr(),
|
||||
|
siglen,
|
||||
|
msg.as_ptr(),
|
||||
|
msg.len(),
|
||||
|
pk.as_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
|
||||
|
// --- check verification of different signature ------------------
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
-1,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
|
||||
|
detached_sig.as_ptr(),
|
||||
|
siglen,
|
||||
|
msg.as_ptr(),
|
||||
|
msg.len() - 1,
|
||||
|
pk.as_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
|
||||
|
// --- check verification against a different pub key -------------
|
||||
|
|
||||
|
let mut pk_alt = [0u8; PK_LEN];
|
||||
|
let mut sk_alt = [0u8; SK_LEN];
|
||||
|
assert_eq!(
|
||||
|
0,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
|
||||
|
pk_alt.as_mut_ptr(),
|
||||
|
sk_alt.as_mut_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
-1,
|
||||
|
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
|
||||
|
detached_sig.as_ptr(),
|
||||
|
siglen,
|
||||
|
msg.as_ptr(),
|
||||
|
msg.len(),
|
||||
|
pk_alt.as_ptr()
|
||||
|
)
|
||||
|
);
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,227 @@ |
|||||
|
use super::{
|
||||
|
ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
|
||||
|
PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
|
||||
|
};
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
|
||||
|
|
||||
|
// PUBLIC KEY
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// A public key for verifying signatures.
|
||||
|
///
|
||||
|
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
|
||||
|
/// the polynomial representing the raw bytes of the expanded public key.
|
||||
|
///
|
||||
|
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512) i.e., 9.
|
||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
|
pub struct PublicKey(Word);
|
||||
|
|
||||
|
impl PublicKey {
|
||||
|
/// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if the decoding of the public key fails.
|
||||
|
pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
|
||||
|
let h = Polynomial::from_pub_key(&pk)?;
|
||||
|
let pk_felts = h.to_elements();
|
||||
|
let pk_digest = Rpo256::hash_elements(&pk_felts).into();
|
||||
|
Ok(Self(pk_digest))
|
||||
|
}
|
||||
|
|
||||
|
/// Verifies the provided signature against provided message and this public key.
|
||||
|
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
|
||||
|
signature.verify(message, self.0)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
impl From<PublicKey> for Word {
|
||||
|
fn from(key: PublicKey) -> Self {
|
||||
|
key.0
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// KEY PAIR
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// A key pair (public and secret keys) for signing messages.
|
||||
|
///
|
||||
|
/// The secret key is a byte array of length [PK_LEN].
|
||||
|
/// The public key is a byte array of length [SK_LEN].
|
||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
|
pub struct KeyPair {
|
||||
|
public_key: PublicKeyBytes,
|
||||
|
secret_key: SecretKeyBytes,
|
||||
|
}
|
||||
|
|
||||
|
#[allow(clippy::new_without_default)]
|
||||
|
impl KeyPair {
|
||||
|
// CONSTRUCTORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Generates a (public_key, secret_key) key pair from OS-provided randomness.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if key generation fails.
|
||||
|
#[cfg(feature = "std")]
|
||||
|
pub fn new() -> Result<Self, FalconError> {
|
||||
|
let mut public_key = [0u8; PK_LEN];
|
||||
|
let mut secret_key = [0u8; SK_LEN];
|
||||
|
|
||||
|
let res = unsafe {
|
||||
|
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
|
||||
|
public_key.as_mut_ptr(),
|
||||
|
secret_key.as_mut_ptr(),
|
||||
|
)
|
||||
|
};
|
||||
|
|
||||
|
if res == 0 {
|
||||
|
Ok(Self { public_key, secret_key })
|
||||
|
} else {
|
||||
|
Err(FalconError::KeyGenerationFailed)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Generates a (public_key, secret_key) key pair from the provided seed.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if key generation fails.
|
||||
|
#[cfg(feature = "std")]
|
||||
|
pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
|
||||
|
let mut public_key = [0u8; PK_LEN];
|
||||
|
let mut secret_key = [0u8; SK_LEN];
|
||||
|
|
||||
|
let res = unsafe {
|
||||
|
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
|
||||
|
public_key.as_mut_ptr(),
|
||||
|
secret_key.as_mut_ptr(),
|
||||
|
seed.as_ptr(),
|
||||
|
)
|
||||
|
};
|
||||
|
|
||||
|
if res == 0 {
|
||||
|
Ok(Self { public_key, secret_key })
|
||||
|
} else {
|
||||
|
Err(FalconError::KeyGenerationFailed)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns the public key corresponding to this key pair.
|
||||
|
pub fn public_key(&self) -> PublicKey {
|
||||
|
// TODO: memoize public key commitment as computing it requires quite a bit of hashing.
|
||||
|
// expect() is fine here because we assume that the key pair was constructed correctly.
|
||||
|
PublicKey::new(self.public_key).expect("invalid key pair")
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the expanded public key corresponding to this key pair.
|
||||
|
pub fn expanded_public_key(&self) -> PublicKeyBytes {
|
||||
|
self.public_key
|
||||
|
}
|
||||
|
|
||||
|
// SIGNATURE GENERATION
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Signs a message with a secret key and a seed.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error of signature generation fails.
|
||||
|
#[cfg(feature = "std")]
|
||||
|
pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
|
||||
|
let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
|
||||
|
let msg_len = msg.len();
|
||||
|
let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
|
||||
|
let mut sig_len: usize = 0;
|
||||
|
|
||||
|
let res = unsafe {
|
||||
|
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
|
||||
|
sig.as_mut_ptr(),
|
||||
|
&mut sig_len as *mut usize,
|
||||
|
msg.as_ptr(),
|
||||
|
msg_len,
|
||||
|
self.secret_key.as_ptr(),
|
||||
|
)
|
||||
|
};
|
||||
|
|
||||
|
if res == 0 {
|
||||
|
Ok(Signature { sig, pk: self.public_key })
|
||||
|
} else {
|
||||
|
Err(FalconError::SigGenerationFailed)
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// SERIALIZATION / DESERIALIZATION
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
impl Serializable for KeyPair {
|
||||
|
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
|
target.write_bytes(&self.public_key);
|
||||
|
target.write_bytes(&self.secret_key);
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
impl Deserializable for KeyPair {
|
||||
|
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
|
let public_key: PublicKeyBytes = source.read_array()?;
|
||||
|
let secret_key: SecretKeyBytes = source.read_array()?;
|
||||
|
Ok(Self { public_key, secret_key })
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// TESTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[cfg(all(test, feature = "std"))]
|
||||
|
mod tests {
|
||||
|
use super::{super::Felt, KeyPair, NonceBytes, Word};
|
||||
|
use rand_utils::{rand_array, rand_vector};
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_falcon_verification() {
|
||||
|
// generate random keys
|
||||
|
let keys = KeyPair::new().unwrap();
|
||||
|
let pk = keys.public_key();
|
||||
|
|
||||
|
// sign a random message
|
||||
|
let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
|
||||
|
let signature = keys.sign(message);
|
||||
|
|
||||
|
// make sure the signature verifies correctly
|
||||
|
assert!(pk.verify(message, signature.as_ref().unwrap()));
|
||||
|
|
||||
|
// a signature should not verify against a wrong message
|
||||
|
let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
|
||||
|
assert!(!pk.verify(message2, signature.as_ref().unwrap()));
|
||||
|
|
||||
|
// a signature should not verify against a wrong public key
|
||||
|
let keys2 = KeyPair::new().unwrap();
|
||||
|
assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
|
||||
|
}
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_falcon_verification_from_seed() {
|
||||
|
// generate keys from a random seed
|
||||
|
let seed: NonceBytes = rand_array();
|
||||
|
let keys = KeyPair::from_seed(&seed).unwrap();
|
||||
|
let pk = keys.public_key();
|
||||
|
|
||||
|
// sign a random message
|
||||
|
let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
|
||||
|
let signature = keys.sign(message);
|
||||
|
|
||||
|
// make sure the signature verifies correctly
|
||||
|
assert!(pk.verify(message, signature.as_ref().unwrap()));
|
||||
|
|
||||
|
// a signature should not verify against a wrong message
|
||||
|
let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
|
||||
|
assert!(!pk.verify(message2, signature.as_ref().unwrap()));
|
||||
|
|
||||
|
// a signature should not verify against a wrong public key
|
||||
|
let keys2 = KeyPair::new().unwrap();
|
||||
|
assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,60 @@ |
|||||
|
use crate::{
|
||||
|
hash::rpo::Rpo256,
|
||||
|
utils::{
|
||||
|
collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
|
||||
|
Serializable,
|
||||
|
},
|
||||
|
Felt, StarkField, Word, ZERO,
|
||||
|
};
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
mod ffi;
|
||||
|
|
||||
|
mod error;
|
||||
|
mod keys;
|
||||
|
mod polynomial;
|
||||
|
mod signature;
|
||||
|
|
||||
|
pub use error::FalconError;
|
||||
|
pub use keys::{KeyPair, PublicKey};
|
||||
|
pub use polynomial::Polynomial;
|
||||
|
pub use signature::Signature;
|
||||
|
|
||||
|
// CONSTANTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
// The Falcon modulus.
|
||||
|
const MODULUS: u16 = 12289;
|
||||
|
const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;
|
||||
|
|
||||
|
// The Falcon parameters for Falcon-512. This is the degree of the polynomial `phi := x^N + 1`
|
||||
|
// defining the ring Z_p[x]/(phi).
|
||||
|
const N: usize = 512;
|
||||
|
const LOG_N: usize = 9;
|
||||
|
|
||||
|
/// Length of nonce used for key-pair generation.
|
||||
|
const NONCE_LEN: usize = 40;
|
||||
|
|
||||
|
/// Number of filed elements used to encode a nonce.
|
||||
|
const NONCE_ELEMENTS: usize = 8;
|
||||
|
|
||||
|
/// Public key length as a u8 vector.
|
||||
|
const PK_LEN: usize = 897;
|
||||
|
|
||||
|
/// Secret key length as a u8 vector.
|
||||
|
const SK_LEN: usize = 1281;
|
||||
|
|
||||
|
/// Signature length as a u8 vector.
|
||||
|
const SIG_LEN: usize = 626;
|
||||
|
|
||||
|
/// Bound on the squared-norm of the signature.
|
||||
|
const SIG_L2_BOUND: u64 = 34034726;
|
||||
|
|
||||
|
// TYPE ALIASES
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
|
||||
|
type PublicKeyBytes = [u8; PK_LEN];
|
||||
|
type SecretKeyBytes = [u8; SK_LEN];
|
||||
|
type NonceBytes = [u8; NONCE_LEN];
|
||||
|
type NonceElements = [Felt; NONCE_ELEMENTS];
|
@ -0,0 +1,277 @@ |
|||||
|
use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
|
||||
|
use core::ops::{Add, Mul, Sub};
|
||||
|
|
||||
|
// FALCON POLYNOMIAL
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
|
||||
|
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
|
pub struct Polynomial([u16; N]);
|
||||
|
|
||||
|
impl Polynomial {
|
||||
|
// CONSTRUCTORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Constructs a new polynomial from a list of coefficients.
|
||||
|
///
|
||||
|
/// # Safety
|
||||
|
/// This constructor validates that the coefficients are in the valid range only in debug mode.
|
||||
|
pub unsafe fn new(data: [u16; N]) -> Self {
|
||||
|
for value in data {
|
||||
|
debug_assert!(value < MODULUS);
|
||||
|
}
|
||||
|
|
||||
|
Self(data)
|
||||
|
}
|
||||
|
|
||||
|
/// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if:
|
||||
|
/// - The provided input is not exactly 897 bytes long.
|
||||
|
/// - The first byte of the input is not equal to log2(512) i.e., 9.
|
||||
|
/// - Any of the coefficients encoded in the provided input is greater than or equal to the
|
||||
|
/// Falcon field modulus.
|
||||
|
pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
|
||||
|
if input.len() != PK_LEN {
|
||||
|
return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
|
||||
|
}
|
||||
|
|
||||
|
if input[0] != LOG_N as u8 {
|
||||
|
return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
|
||||
|
}
|
||||
|
|
||||
|
let mut acc = 0_u32;
|
||||
|
let mut acc_len = 0;
|
||||
|
|
||||
|
let mut output = [0_u16; N];
|
||||
|
let mut output_idx = 0;
|
||||
|
|
||||
|
for &byte in input.iter().skip(1) {
|
||||
|
acc = (acc << 8) | (byte as u32);
|
||||
|
acc_len += 8;
|
||||
|
|
||||
|
if acc_len >= 14 {
|
||||
|
acc_len -= 14;
|
||||
|
let w = (acc >> acc_len) & 0x3FFF;
|
||||
|
if w >= MODULUS as u32 {
|
||||
|
return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
|
||||
|
}
|
||||
|
output[output_idx] = w as u16;
|
||||
|
output_idx += 1;
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
if (acc & ((1u32 << acc_len) - 1)) == 0 {
|
||||
|
Ok(Self(output))
|
||||
|
} else {
|
||||
|
Err(FalconError::PubKeyDecodingExtraData)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
|
||||
|
/// that the signature has been encoded using the uncompressed format.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if:
|
||||
|
/// - The signature has been encoded using a different algorithm than the reference compressed
|
||||
|
/// encoding algorithm.
|
||||
|
/// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
|
||||
|
/// - While decoding the high bits of a coefficient, the current accumulated value of its
|
||||
|
/// high bits is larger than 2048.
|
||||
|
/// - The decoded coefficient is -0.
|
||||
|
/// - The remaining unused bits in the last byte of `input` are non-zero.
|
||||
|
pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
|
||||
|
let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);
|
||||
|
|
||||
|
if encoding != 0b0011 {
|
||||
|
return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
|
||||
|
}
|
||||
|
if log_n != 0b1001 {
|
||||
|
return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
|
||||
|
}
|
||||
|
|
||||
|
let input = &input[41..];
|
||||
|
let mut input_idx = 0;
|
||||
|
let mut acc = 0u32;
|
||||
|
let mut acc_len = 0;
|
||||
|
let mut output = [0_u16; N];
|
||||
|
|
||||
|
for e in output.iter_mut() {
|
||||
|
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
|
input_idx += 1;
|
||||
|
let b = acc >> acc_len;
|
||||
|
let s = b & 128;
|
||||
|
let mut m = b & 127;
|
||||
|
|
||||
|
loop {
|
||||
|
if acc_len == 0 {
|
||||
|
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
|
input_idx += 1;
|
||||
|
acc_len = 8;
|
||||
|
}
|
||||
|
acc_len -= 1;
|
||||
|
if ((acc >> acc_len) & 1) != 0 {
|
||||
|
break;
|
||||
|
}
|
||||
|
m += 128;
|
||||
|
if m >= 2048 {
|
||||
|
return Err(FalconError::SigDecodingTooBigHighBits(m));
|
||||
|
}
|
||||
|
}
|
||||
|
if s != 0 && m == 0 {
|
||||
|
return Err(FalconError::SigDecodingMinusZero);
|
||||
|
}
|
||||
|
|
||||
|
*e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
|
||||
|
}
|
||||
|
|
||||
|
if (acc & ((1 << acc_len) - 1)) != 0 {
|
||||
|
return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
|
||||
|
}
|
||||
|
|
||||
|
Ok(Self(output))
|
||||
|
}
|
||||
|
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns the coefficients of this polynomial as integers.
|
||||
|
pub fn inner(&self) -> [u16; N] {
|
||||
|
self.0
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the coefficients of this polynomial as field elements.
|
||||
|
pub fn to_elements(&self) -> Vec<Felt> {
|
||||
|
self.0.iter().map(|&a| Felt::from(a)).collect()
|
||||
|
}
|
||||
|
|
||||
|
// POLYNOMIAL OPERATIONS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
|
||||
|
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
|
||||
|
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
|
||||
|
/// than the Miden prime.
|
||||
|
///
|
||||
|
/// Note that this multiplication is not over Z_p[x]/(phi).
|
||||
|
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
|
||||
|
let mut c = [0; 2 * N];
|
||||
|
for i in 0..N {
|
||||
|
for j in 0..N {
|
||||
|
c[i + j] += a.0[i] as u64 * b.0[j] as u64;
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
c
|
||||
|
}
|
||||
|
|
||||
|
/// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
|
||||
|
/// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
|
||||
|
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
|
||||
|
let mut c = [0; N];
|
||||
|
for i in 0..N {
|
||||
|
let ai = a[N + i] % MODULUS as u64;
|
||||
|
let neg_ai = (MODULUS - ai as u16) % MODULUS;
|
||||
|
|
||||
|
let bi = (a[i] % MODULUS as u64) as u16;
|
||||
|
c[i] = (neg_ai + bi) % MODULUS;
|
||||
|
}
|
||||
|
|
||||
|
Self(c)
|
||||
|
}
|
||||
|
|
||||
|
/// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
|
||||
|
/// coefficients to be in the interval (-p/2, p/2].
|
||||
|
pub fn sq_norm(&self) -> u64 {
|
||||
|
let mut res = 0;
|
||||
|
for e in self.0 {
|
||||
|
if e > MODULUS_MINUS_1_OVER_TWO {
|
||||
|
res += (MODULUS - e) as u64 * (MODULUS - e) as u64
|
||||
|
} else {
|
||||
|
res += e as u64 * e as u64
|
||||
|
}
|
||||
|
}
|
||||
|
res
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// Returns a polynomial representing the zero polynomial i.e. default element.
|
||||
|
impl Default for Polynomial {
|
||||
|
fn default() -> Self {
|
||||
|
Self([0_u16; N])
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Multiplication over Z_p[x]/(phi)
|
||||
|
impl Mul for Polynomial {
|
||||
|
type Output = Self;
|
||||
|
|
||||
|
fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
|
||||
|
let mut result = [0_u16; N];
|
||||
|
for j in 0..N {
|
||||
|
for k in 0..N {
|
||||
|
let i = (j + k) % N;
|
||||
|
let a = self.0[j] as usize;
|
||||
|
let b = other.0[k] as usize;
|
||||
|
let q = MODULUS as usize;
|
||||
|
let mut prod = a * b % q;
|
||||
|
if (N - 1) < (j + k) {
|
||||
|
prod = (q - prod) % q;
|
||||
|
}
|
||||
|
result[i] = ((result[i] as usize + prod) % q) as u16;
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
Polynomial(result)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Addition over Z_p[x]/(phi)
|
||||
|
impl Add for Polynomial {
|
||||
|
type Output = Self;
|
||||
|
|
||||
|
fn add(self, other: Self) -> <Self as Add<Self>>::Output {
|
||||
|
let mut res = self;
|
||||
|
res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);
|
||||
|
|
||||
|
res
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Subtraction over Z_p[x]/(phi)
|
||||
|
impl Sub for Polynomial {
|
||||
|
type Output = Self;
|
||||
|
|
||||
|
fn sub(self, other: Self) -> <Self as Add<Self>>::Output {
|
||||
|
let mut res = self;
|
||||
|
res.0
|
||||
|
.iter_mut()
|
||||
|
.zip(other.0.iter())
|
||||
|
.for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);
|
||||
|
|
||||
|
res
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// TESTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
mod tests {
|
||||
|
use super::{Polynomial, N};
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_negacyclic_reduction() {
|
||||
|
let coef1: [u16; N] = rand_utils::rand_array();
|
||||
|
let coef2: [u16; N] = rand_utils::rand_array();
|
||||
|
|
||||
|
let poly1 = Polynomial(coef1);
|
||||
|
let poly2 = Polynomial(coef2);
|
||||
|
|
||||
|
assert_eq!(
|
||||
|
poly1 * poly2,
|
||||
|
Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
|
||||
|
);
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,262 @@ |
|||||
|
use super::{
|
||||
|
ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
|
||||
|
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
|
||||
|
SIG_L2_BOUND, ZERO,
|
||||
|
};
|
||||
|
use crate::utils::string::ToString;
|
||||
|
|
||||
|
// FALCON SIGNATURE
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// An RPO Falcon512 signature over a message.
|
||||
|
///
|
||||
|
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
|
||||
|
/// - p := 12289
|
||||
|
/// - phi := x^512 + 1
|
||||
|
/// - s1 = c - s2 * h
|
||||
|
/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point
|
||||
|
/// of the message being signed.
|
||||
|
///
|
||||
|
/// The signature verifies if and only if:
|
||||
|
/// 1. s1 = c - s2 * h
|
||||
|
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
|
||||
|
///
|
||||
|
/// where |.| is the norm.
|
||||
|
///
|
||||
|
/// [Signature] also includes the extended public key which is serialized as:
|
||||
|
/// 1. 1 byte representing the log2(512) i.e., 9.
|
||||
|
/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above.
|
||||
|
///
|
||||
|
/// The actual signature is serialized as:
|
||||
|
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
|
||||
|
/// together with the degree of the irreducible polynomial phi.
|
||||
|
/// The general format of this byte is 0b0cc1nnnn where:
|
||||
|
/// a. cc is either 01 when the compressed encoding algorithm is used and 10 when the
|
||||
|
/// uncompressed algorithm is used.
|
||||
|
/// b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi.
|
||||
|
/// The current implementation works always with cc equal to 0b01 and nnnn equal to 0b1001 and
|
||||
|
/// thus the header byte is always equal to 0b00111001.
|
||||
|
/// 2. 40 bytes for the nonce.
|
||||
|
/// 3. 625 bytes encoding the `s2` polynomial above.
|
||||
|
///
|
||||
|
/// The total size of the signature (including the extended public key) is 1563 bytes.
|
||||
|
pub struct Signature {
|
||||
|
pub(super) pk: PublicKeyBytes,
|
||||
|
pub(super) sig: SignatureBytes,
|
||||
|
}
|
||||
|
|
||||
|
impl Signature {
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns the public key polynomial h.
|
||||
|
pub fn pub_key_poly(&self) -> Polynomial {
|
||||
|
// TODO: memoize
|
||||
|
// we assume that the signature was constructed with a valid public key, and thus
|
||||
|
// expect() is OK here.
|
||||
|
Polynomial::from_pub_key(&self.pk).expect("invalid public key")
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the nonce component of the signature represented as field elements.
|
||||
|
///
|
||||
|
/// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
|
||||
|
/// of the nonce and interpreting them as field elements.
|
||||
|
pub fn nonce(&self) -> NonceElements {
|
||||
|
// we assume that the signature was constructed with a valid signature, and thus
|
||||
|
// expect() is OK here.
|
||||
|
let nonce = self.sig[1..41].try_into().expect("invalid signature");
|
||||
|
decode_nonce(nonce)
|
||||
|
}
|
||||
|
|
||||
|
// Returns the polynomial representation of the signature in Z_p[x]/(phi).
|
||||
|
pub fn sig_poly(&self) -> Polynomial {
|
||||
|
// TODO: memoize
|
||||
|
// we assume that the signature was constructed with a valid signature, and thus
|
||||
|
// expect() is OK here.
|
||||
|
Polynomial::from_signature(&self.sig).expect("invalid signature")
|
||||
|
}
|
||||
|
|
||||
|
// HASH-TO-POINT
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
|
||||
|
pub fn hash_to_point(&self, message: Word) -> Polynomial {
|
||||
|
hash_to_point(message, &self.nonce())
|
||||
|
}
|
||||
|
|
||||
|
// SIGNATURE VERIFICATION
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Returns true if this signature is a valid signature for the specified message generated
|
||||
|
/// against key pair matching the specified public key commitment.
|
||||
|
pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
|
||||
|
// Make sure the expanded public key matches the provided public key commitment
|
||||
|
let h = self.pub_key_poly();
|
||||
|
let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into();
|
||||
|
if h_digest != pubkey_com {
|
||||
|
return false;
|
||||
|
}
|
||||
|
|
||||
|
// Make sure the signature is valid
|
||||
|
let s2 = self.sig_poly();
|
||||
|
let c = self.hash_to_point(message);
|
||||
|
|
||||
|
let s1 = c - s2 * h;
|
||||
|
|
||||
|
let sq_norm = s1.sq_norm() + s2.sq_norm();
|
||||
|
sq_norm <= SIG_L2_BOUND
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// SERIALIZATION / DESERIALIZATION
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
impl Serializable for Signature {
|
||||
|
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
|
target.write_bytes(&self.pk);
|
||||
|
target.write_bytes(&self.sig);
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
impl Deserializable for Signature {
|
||||
|
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
|
let pk: PublicKeyBytes = source.read_array()?;
|
||||
|
let sig: SignatureBytes = source.read_array()?;
|
||||
|
|
||||
|
// make sure public key and signature can be decoded correctly
|
||||
|
Polynomial::from_pub_key(&pk)
|
||||
|
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
|
Polynomial::from_signature(&sig[41..])
|
||||
|
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
|
|
||||
|
Ok(Self { pk, sig })
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// HELPER FUNCTIONS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
|
||||
|
/// nonce.
|
||||
|
fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial {
|
||||
|
let mut state = [ZERO; Rpo256::STATE_WIDTH];
|
||||
|
|
||||
|
// absorb the nonce into the state
|
||||
|
for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
|
||||
|
*s = n;
|
||||
|
}
|
||||
|
Rpo256::apply_permutation(&mut state);
|
||||
|
|
||||
|
// absorb message into the state
|
||||
|
for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
|
||||
|
*s = m;
|
||||
|
}
|
||||
|
|
||||
|
// squeeze the coefficients of the polynomial
|
||||
|
let mut i = 0;
|
||||
|
let mut res = [0_u16; N];
|
||||
|
for _ in 0..64 {
|
||||
|
Rpo256::apply_permutation(&mut state);
|
||||
|
for a in &state[Rpo256::RATE_RANGE] {
|
||||
|
res[i] = (a.as_int() % MODULUS as u64) as u16;
|
||||
|
i += 1;
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// using the raw constructor is OK here because we reduce all coefficients by the modulus above
|
||||
|
unsafe { Polynomial::new(res) }
|
||||
|
}
|
||||
|
|
||||
|
/// Converts byte representation of the nonce into field element representation.
|
||||
|
fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
||||
|
let mut buffer = [0_u8; 8];
|
||||
|
let mut result = [ZERO; 8];
|
||||
|
for (i, bytes) in nonce.chunks(5).enumerate() {
|
||||
|
buffer[..5].copy_from_slice(bytes);
|
||||
|
result[i] = u64::from_le_bytes(buffer).into();
|
||||
|
}
|
||||
|
|
||||
|
result
|
||||
|
}
|
||||
|
|
||||
|
// TESTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[cfg(all(test, feature = "std"))]
|
||||
|
mod tests {
|
||||
|
use super::{
|
||||
|
super::{ffi::*, Felt},
|
||||
|
*,
|
||||
|
};
|
||||
|
use libc::c_void;
|
||||
|
use rand_utils::rand_vector;
|
||||
|
|
||||
|
// Wrappers for unsafe functions
|
||||
|
impl Rpo128Context {
|
||||
|
/// Initializes the RPO state.
|
||||
|
pub fn init() -> Self {
|
||||
|
let mut ctx = Rpo128Context { content: [0u64; 13] };
|
||||
|
unsafe {
|
||||
|
rpo128_init(&mut ctx as *mut Rpo128Context);
|
||||
|
}
|
||||
|
ctx
|
||||
|
}
|
||||
|
|
||||
|
/// Absorbs data into the RPO state.
|
||||
|
pub fn absorb(&mut self, data: &[u8]) {
|
||||
|
unsafe {
|
||||
|
rpo128_absorb(
|
||||
|
self as *mut Rpo128Context,
|
||||
|
data.as_ptr() as *const c_void,
|
||||
|
data.len(),
|
||||
|
)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Finalizes the RPO state to prepare for squeezing.
|
||||
|
pub fn finalize(&mut self) {
|
||||
|
unsafe { rpo128_finalize(self as *mut Rpo128Context) }
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_hash_to_point() {
|
||||
|
// Create a random message and transform it into a u8 vector
|
||||
|
let msg_felts: Word = rand_vector::<Felt>(4).try_into().unwrap();
|
||||
|
let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
|
||||
|
|
||||
|
// Create a nonce i.e. a [u8; 40] array and pack into a [Felt; 8] array.
|
||||
|
let nonce: [u8; 40] = rand_vector::<u8>(40).try_into().unwrap();
|
||||
|
|
||||
|
let mut buffer = [0_u8; 64];
|
||||
|
for i in 0..8 {
|
||||
|
buffer[8 * i] = nonce[5 * i];
|
||||
|
buffer[8 * i + 1] = nonce[5 * i + 1];
|
||||
|
buffer[8 * i + 2] = nonce[5 * i + 2];
|
||||
|
buffer[8 * i + 3] = nonce[5 * i + 3];
|
||||
|
buffer[8 * i + 4] = nonce[5 * i + 4];
|
||||
|
}
|
||||
|
|
||||
|
// Initialize the RPO state
|
||||
|
let mut rng = Rpo128Context::init();
|
||||
|
|
||||
|
// Absorb the nonce and message into the RPO state
|
||||
|
rng.absorb(&buffer);
|
||||
|
rng.absorb(&msg_bytes);
|
||||
|
rng.finalize();
|
||||
|
|
||||
|
// Generate the coefficients of the hash-to-point polynomial.
|
||||
|
let mut res: [u16; N] = [0; N];
|
||||
|
|
||||
|
unsafe {
|
||||
|
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
|
||||
|
&mut rng as *mut Rpo128Context,
|
||||
|
res.as_mut_ptr(),
|
||||
|
9,
|
||||
|
);
|
||||
|
}
|
||||
|
|
||||
|
// Check that the coefficients are correct
|
||||
|
let nonce = decode_nonce(&nonce);
|
||||
|
assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,129 @@ |
|||||
|
use clap::Parser;
|
||||
|
use miden_crypto::{
|
||||
|
hash::rpo::RpoDigest,
|
||||
|
merkle::MerkleError,
|
||||
|
Felt, Word, ONE,
|
||||
|
{hash::rpo::Rpo256, merkle::TieredSmt},
|
||||
|
};
|
||||
|
use rand_utils::rand_value;
|
||||
|
use std::time::Instant;
|
||||
|
|
||||
|
#[derive(Parser, Debug)]
|
||||
|
#[clap(
|
||||
|
name = "Benchmark",
|
||||
|
about = "Tiered SMT benchmark",
|
||||
|
version,
|
||||
|
rename_all = "kebab-case"
|
||||
|
)]
|
||||
|
pub struct BenchmarkCmd {
|
||||
|
/// Size of the tree
|
||||
|
#[clap(short = 's', long = "size")]
|
||||
|
size: u64,
|
||||
|
}
|
||||
|
|
||||
|
fn main() {
|
||||
|
benchmark_tsmt();
|
||||
|
}
|
||||
|
|
||||
|
/// Run a benchmark for the Tiered SMT.
|
||||
|
pub fn benchmark_tsmt() {
|
||||
|
let args = BenchmarkCmd::parse();
|
||||
|
let tree_size = args.size;
|
||||
|
|
||||
|
// prepare the `leaves` vector for tree creation
|
||||
|
let mut entries = Vec::new();
|
||||
|
for i in 0..tree_size {
|
||||
|
let key = rand_value::<RpoDigest>();
|
||||
|
let value = [ONE, ONE, ONE, Felt::new(i)];
|
||||
|
entries.push((key, value));
|
||||
|
}
|
||||
|
|
||||
|
let mut tree = construction(entries, tree_size).unwrap();
|
||||
|
insertion(&mut tree, tree_size).unwrap();
|
||||
|
proof_generation(&mut tree, tree_size).unwrap();
|
||||
|
}
|
||||
|
|
||||
|
/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
|
||||
|
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
|
||||
|
println!("Running a construction benchmark:");
|
||||
|
let now = Instant::now();
|
||||
|
let tree = TieredSmt::with_entries(entries)?;
|
||||
|
let elapsed = now.elapsed();
|
||||
|
println!(
|
||||
|
"Constructed a TSMT with {} key-value pairs in {:.3} seconds",
|
||||
|
size,
|
||||
|
elapsed.as_secs_f32(),
|
||||
|
);
|
||||
|
|
||||
|
// Count how many nodes end up at each tier
|
||||
|
let mut nodes_num_16_32_48 = (0, 0, 0);
|
||||
|
|
||||
|
tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
|
||||
|
16 => nodes_num_16_32_48.0 += 1,
|
||||
|
32 => nodes_num_16_32_48.1 += 1,
|
||||
|
48 => nodes_num_16_32_48.2 += 1,
|
||||
|
_ => unreachable!(),
|
||||
|
});
|
||||
|
|
||||
|
println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
|
||||
|
println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
|
||||
|
println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
|
||||
|
println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
|
||||
|
|
||||
|
Ok(tree)
|
||||
|
}
|
||||
|
|
||||
|
/// Runs the insertion benchmark for the Tiered SMT.
|
||||
|
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
|
println!("Running an insertion benchmark:");
|
||||
|
|
||||
|
let mut insertion_times = Vec::new();
|
||||
|
|
||||
|
for i in 0..20 {
|
||||
|
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
||||
|
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
|
||||
|
|
||||
|
let now = Instant::now();
|
||||
|
tree.insert(test_key, test_value);
|
||||
|
let elapsed = now.elapsed();
|
||||
|
insertion_times.push(elapsed.as_secs_f32());
|
||||
|
}
|
||||
|
|
||||
|
println!(
|
||||
|
"An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
|
||||
|
size,
|
||||
|
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
|
||||
|
// 1000. As a result, we can only multiply by 50
|
||||
|
insertion_times.iter().sum::<f32>() * 50f32,
|
||||
|
);
|
||||
|
|
||||
|
Ok(())
|
||||
|
}
|
||||
|
|
||||
|
/// Runs the proof generation benchmark for the Tiered SMT.
|
||||
|
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
|
println!("Running a proof generation benchmark:");
|
||||
|
|
||||
|
let mut insertion_times = Vec::new();
|
||||
|
|
||||
|
for i in 0..20 {
|
||||
|
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
||||
|
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
|
||||
|
tree.insert(test_key, test_value);
|
||||
|
|
||||
|
let now = Instant::now();
|
||||
|
let _proof = tree.prove(test_key);
|
||||
|
let elapsed = now.elapsed();
|
||||
|
insertion_times.push(elapsed.as_secs_f32());
|
||||
|
}
|
||||
|
|
||||
|
println!(
|
||||
|
"An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
|
||||
|
size,
|
||||
|
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
|
||||
|
// 1000000. As a result, we can only multiply by 50000
|
||||
|
insertion_times.iter().sum::<f32>() * 50000f32,
|
||||
|
);
|
||||
|
|
||||
|
Ok(())
|
||||
|
}
|
@ -0,0 +1,156 @@ |
|||||
|
use super::{
|
||||
|
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
|
||||
|
};
|
||||
|
use crate::utils::collections::Diff;
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
|
||||
|
|
||||
|
// MERKLE STORE DELTA
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
|
||||
|
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
|
||||
|
/// differences between the initial and final Merkle tree states.
|
||||
|
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
|
||||
|
|
||||
|
// MERKLE TREE DELTA
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
|
||||
|
///
|
||||
|
/// The differences are represented as follows:
|
||||
|
/// - depth: the depth of the merkle tree.
|
||||
|
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
|
||||
|
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
|
#[cfg(not(test))]
|
||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub struct MerkleTreeDelta {
|
||||
|
depth: u8,
|
||||
|
cleared_slots: Vec<u64>,
|
||||
|
updated_slots: Vec<(u64, Word)>,
|
||||
|
}
|
||||
|
|
||||
|
impl MerkleTreeDelta {
|
||||
|
// CONSTRUCTOR
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
pub fn new(depth: u8) -> Self {
|
||||
|
Self {
|
||||
|
depth,
|
||||
|
cleared_slots: Vec::new(),
|
||||
|
updated_slots: Vec::new(),
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Returns the depth of the Merkle tree the [MerkleDelta] is associated with.
|
||||
|
pub fn depth(&self) -> u8 {
|
||||
|
self.depth
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the indexes of slots where values were set to [ZERO; 4].
|
||||
|
pub fn cleared_slots(&self) -> &[u64] {
|
||||
|
&self.cleared_slots
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
|
pub fn updated_slots(&self) -> &[(u64, Word)] {
|
||||
|
&self.updated_slots
|
||||
|
}
|
||||
|
|
||||
|
// MODIFIERS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Adds a slot index to the list of cleared slots.
|
||||
|
pub fn add_cleared_slot(&mut self, index: u64) {
|
||||
|
self.cleared_slots.push(index);
|
||||
|
}
|
||||
|
|
||||
|
/// Adds a slot index and a value to the list of updated slots.
|
||||
|
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
|
||||
|
self.updated_slots.push((index, value));
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by
|
||||
|
/// their roots and depth.
|
||||
|
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
|
||||
|
tree_root_1: RpoDigest,
|
||||
|
tree_root_2: RpoDigest,
|
||||
|
depth: u8,
|
||||
|
merkle_store: &MerkleStore<T>,
|
||||
|
) -> Result<MerkleTreeDelta, MerkleError> {
|
||||
|
if tree_root_1 == tree_root_2 {
|
||||
|
return Ok(MerkleTreeDelta::new(depth));
|
||||
|
}
|
||||
|
|
||||
|
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
|
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
|
||||
|
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
|
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
|
||||
|
let diff = tree_1_leaves.diff(&tree_2_leaves);
|
||||
|
|
||||
|
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
|
||||
|
Ok(MerkleTreeDelta {
|
||||
|
depth,
|
||||
|
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
|
||||
|
updated_slots: diff
|
||||
|
.updated
|
||||
|
.into_iter()
|
||||
|
.map(|(index, leaf)| (index.value(), *leaf))
|
||||
|
.collect(),
|
||||
|
})
|
||||
|
}
|
||||
|
|
||||
|
// INTERNALS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
#[cfg(test)]
|
||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub struct MerkleTreeDelta {
|
||||
|
pub depth: u8,
|
||||
|
pub cleared_slots: Vec<u64>,
|
||||
|
pub updated_slots: Vec<(u64, Word)>,
|
||||
|
}
|
||||
|
|
||||
|
// MERKLE DELTA
|
||||
|
// ================================================================================================
|
||||
|
#[test]
|
||||
|
fn test_compute_merkle_delta() {
|
||||
|
let entries = vec![
|
||||
|
(10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
|
||||
|
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
|
||||
|
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
|
||||
|
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
|
||||
|
];
|
||||
|
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
|
||||
|
let mut store: MerkleStore = (&simple_smt).into();
|
||||
|
let root = simple_smt.root();
|
||||
|
|
||||
|
// add a new node
|
||||
|
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
|
||||
|
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
|
||||
|
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
|
||||
|
|
||||
|
// update an existing node
|
||||
|
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
|
||||
|
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
|
||||
|
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
|
||||
|
|
||||
|
// remove a node
|
||||
|
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
|
||||
|
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
|
||||
|
|
||||
|
let merkle_delta =
|
||||
|
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
|
||||
|
let expected_merkle_delta = MerkleTreeDelta {
|
||||
|
depth: simple_smt.depth(),
|
||||
|
cleared_slots: vec![remove_idx.value()],
|
||||
|
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
|
||||
|
};
|
||||
|
|
||||
|
assert_eq!(merkle_delta, expected_merkle_delta);
|
||||
|
}
|
@ -0,0 +1,54 @@ |
|||||
|
use crate::{
|
||||
|
merkle::{MerklePath, NodeIndex, RpoDigest},
|
||||
|
utils::collections::Vec,
|
||||
|
};
|
||||
|
use core::fmt;
|
||||
|
|
||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
|
pub enum MerkleError {
|
||||
|
ConflictingRoots(Vec<RpoDigest>),
|
||||
|
DepthTooSmall(u8),
|
||||
|
DepthTooBig(u64),
|
||||
|
DuplicateValuesForIndex(u64),
|
||||
|
DuplicateValuesForKey(RpoDigest),
|
||||
|
InvalidIndex { depth: u8, value: u64 },
|
||||
|
InvalidDepth { expected: u8, provided: u8 },
|
||||
|
InvalidPath(MerklePath),
|
||||
|
InvalidNumEntries(usize, usize),
|
||||
|
NodeNotInSet(NodeIndex),
|
||||
|
NodeNotInStore(RpoDigest, NodeIndex),
|
||||
|
NumLeavesNotPowerOfTwo(usize),
|
||||
|
RootNotInStore(RpoDigest),
|
||||
|
}
|
||||
|
|
||||
|
impl fmt::Display for MerkleError {
|
||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
|
use MerkleError::*;
|
||||
|
match self {
|
||||
|
ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
|
||||
|
DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
|
||||
|
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
|
||||
|
DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
|
||||
|
DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
|
||||
|
InvalidIndex{ depth, value} => write!(
|
||||
|
f,
|
||||
|
"the index value {value} is not valid for the depth {depth}"
|
||||
|
),
|
||||
|
InvalidDepth { expected, provided } => write!(
|
||||
|
f,
|
||||
|
"the provided depth {provided} is not valid for {expected}"
|
||||
|
),
|
||||
|
InvalidPath(_path) => write!(f, "the provided path is not valid"),
|
||||
|
InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
|
||||
|
NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
|
||||
|
NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
|
||||
|
NumLeavesNotPowerOfTwo(leaves) => {
|
||||
|
write!(f, "the leaves count {leaves} is not a power of 2")
|
||||
|
}
|
||||
|
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
impl std::error::Error for MerkleError {}
|
@ -1,408 +0,0 @@ |
|||||
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, ValuePath, Vec};
|
|
||||
use crate::{hash::rpo::RpoDigest, Word};
|
|
||||
|
|
||||
// MERKLE PATH SET
|
|
||||
// ================================================================================================
|
|
||||
|
|
||||
/// A set of Merkle paths.
|
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||
pub struct MerklePathSet {
|
|
||||
root: RpoDigest,
|
|
||||
total_depth: u8,
|
|
||||
paths: BTreeMap<u64, MerklePath>,
|
|
||||
}
|
|
||||
|
|
||||
impl MerklePathSet {
|
|
||||
// CONSTRUCTOR
|
|
||||
// --------------------------------------------------------------------------------------------
|
|
||||
|
|
||||
/// Returns an empty MerklePathSet.
|
|
||||
pub fn new(depth: u8) -> Self {
|
|
||||
let root = RpoDigest::default();
|
|
||||
let paths = BTreeMap::new();
|
|
||||
|
|
||||
Self {
|
|
||||
root,
|
|
||||
total_depth: depth,
|
|
||||
paths,
|
|
||||
}
|
|
||||
}
|
|
||||
|
|
||||
/// Appends the provided paths iterator into the set.
|
|
||||
///
|
|
||||
/// Analogous to `[Self::add_path]`.
|
|
||||
pub fn with_paths<I>(self, paths: I) -> Result<Self, MerkleError>
|
|
||||
where
|
|
||||
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
|
|
||||
{
|
|
||||
paths.into_iter().try_fold(self, |mut set, (index, value, path)| {
|
|
||||
set.add_path(index, value.into(), path)?;
|
|
||||
Ok(set)
|
|
||||
})
|
|
||||
}
|
|
||||
|
|
||||
// PUBLIC ACCESSORS
|
|
||||
// --------------------------------------------------------------------------------------------
|
|
||||
|
|
||||
/// Returns the root to which all paths in this set resolve.
|
|
||||
pub const fn root(&self) -> RpoDigest {
|
|
||||
self.root
|
|
||||
}
|
|
||||
|
|
||||
/// Returns the depth of the Merkle tree implied by the paths stored in this set.
|
|
||||
///
|
|
||||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
|
||||
pub const fn depth(&self) -> u8 {
|
|
||||
self.total_depth
|
|
||||
}
|
|
||||
|
|
||||
/// Returns a node at the specified index.
|
|
||||
///
|
|
||||
/// # Errors
|
|
||||
/// Returns an error if:
|
|
||||
/// * The specified index is not valid for the depth of structure.
|
|
||||
/// * Requested node does not exist in the set.
|
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
|
||||
if index.depth() != self.total_depth {
|
|
||||
return Err(MerkleError::InvalidDepth {
|
|
||||
expected: self.total_depth,
|
|
||||
provided: index.depth(),
|
|
||||
});
|
|
||||
}
|
|
||||
|
|
||||
let parity = index.value() & 1;
|
|
||||
let path_key = index.value() - parity;
|
|
||||
self.paths
|
|
||||
.get(&path_key)
|
|
||||
.ok_or(MerkleError::NodeNotInSet(index))
|
|
||||
.map(|path| path[parity as usize])
|
|
||||
}
|
|
||||
|
|
||||
/// Returns a leaf at the specified index.
|
|
||||
///
|
|
||||
/// # Errors
|
|
||||
/// * The specified index is not valid for the depth of the structure.
|
|
||||
/// * Leaf with the requested path does not exist in the set.
|
|
||||
pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
|
|
||||
let index = NodeIndex::new(self.depth(), index)?;
|
|
||||
Ok(self.get_node(index)?.into())
|
|
||||
}
|
|
||||
|
|
||||
/// Returns a Merkle path to the node at the specified index. The node itself is
|
|
||||
/// not included in the path.
|
|
||||
///
|
|
||||
/// # Errors
|
|
||||
/// Returns an error if:
|
|
||||
/// * The specified index is not valid for the depth of structure.
|
|
||||
/// * Node of the requested path does not exist in the set.
|
|
||||
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
|
||||
if index.depth() != self.total_depth {
|
|
||||
return Err(MerkleError::InvalidDepth {
|
|
||||
expected: self.total_depth,
|
|
||||
provided: index.depth(),
|
|
||||
});
|
|
||||
}
|
|
||||
|
|
||||
let parity = index.value() & 1;
|
|
||||
let path_key = index.value() - parity;
|
|
||||
let mut path =
|
|
||||
self.paths.get(&path_key).cloned().ok_or(MerkleError::NodeNotInSet(index))?;
|
|
||||
path.remove(parity as usize);
|
|
||||
Ok(path)
|
|
||||
}
|
|
||||
|
|
||||
/// Returns all paths in this path set together with their indexes.
|
|
||||
pub fn to_paths(&self) -> Vec<(u64, ValuePath)> {
|
|
||||
let mut result = Vec::with_capacity(self.paths.len() * 2);
|
|
||||
|
|
||||
for (&index, path) in self.paths.iter() {
|
|
||||
// push path for the even index into the result
|
|
||||
let path1 = ValuePath {
|
|
||||
value: path[0],
|
|
||||
path: MerklePath::new(path[1..].to_vec()),
|
|
||||
};
|
|
||||
result.push((index, path1));
|
|
||||
|
|
||||
// push path for the odd index into the result
|
|
||||
let mut path2 = path.clone();
|
|
||||
let leaf2 = path2.remove(1);
|
|
||||
let path2 = ValuePath {
|
|
||||
value: leaf2,
|
|
||||
path: path2,
|
|
||||
};
|
|
||||
result.push((index + 1, path2));
|
|
||||
}
|
|
||||
|
|
||||
result
|
|
||||
}
|
|
||||
|
|
||||
// STATE MUTATORS
|
|
||||
// --------------------------------------------------------------------------------------------
|
|
||||
|
|
||||
/// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters
|
|
||||
/// specify the leaf node at which the path starts.
|
|
||||
///
|
|
||||
/// # Errors
|
|
||||
/// Returns an error if:
|
|
||||
/// - The specified index is is not valid in the context of this Merkle path set (i.e., the
|
|
||||
/// index implies a greater depth than is specified for this set).
|
|
||||
/// - The specified path is not consistent with other paths in the set (i.e., resolves to a
|
|
||||
/// different root).
|
|
||||
pub fn add_path(
|
|
||||
&mut self,
|
|
||||
index_value: u64,
|
|
||||
value: Word,
|
|
||||
mut path: MerklePath,
|
|
||||
) -> Result<(), MerkleError> {
|
|
||||
let mut index = NodeIndex::new(path.len() as u8, index_value)?;
|
|
||||
if index.depth() != self.total_depth {
|
|
||||
return Err(MerkleError::InvalidDepth {
|
|
||||
expected: self.total_depth,
|
|
||||
provided: index.depth(),
|
|
||||
});
|
|
||||
}
|
|
||||
|
|
||||
// update the current path
|
|
||||
let parity = index_value & 1;
|
|
||||
path.insert(parity as usize, value.into());
|
|
||||
|
|
||||
// traverse to the root, updating the nodes
|
|
||||
let root = Rpo256::merge(&[path[0], path[1]]);
|
|
||||
let root = path.iter().skip(2).copied().fold(root, |root, hash| {
|
|
||||
index.move_up();
|
|
||||
Rpo256::merge(&index.build_node(root, hash))
|
|
||||
});
|
|
||||
|
|
||||
// if the path set is empty (the root is all ZEROs), set the root to the root of the added
|
|
||||
// path; otherwise, the root of the added path must be identical to the current root
|
|
||||
if self.root == RpoDigest::default() {
|
|
||||
self.root = root;
|
|
||||
} else if self.root != root {
|
|
||||
return Err(MerkleError::ConflictingRoots([self.root, root].to_vec()));
|
|
||||
}
|
|
||||
|
|
||||
// finish updating the path
|
|
||||
let path_key = index_value - parity;
|
|
||||
self.paths.insert(path_key, path);
|
|
||||
Ok(())
|
|
||||
}
|
|
||||
|
|
||||
/// Replaces the leaf at the specified index with the provided value.
|
|
||||
///
|
|
||||
/// # Errors
|
|
||||
/// Returns an error if:
|
|
||||
/// * Requested node does not exist in the set.
|
|
||||
pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> {
|
|
||||
let mut index = NodeIndex::new(self.depth(), base_index_value)?;
|
|
||||
let parity = index.value() & 1;
|
|
||||
let path_key = index.value() - parity;
|
|
||||
let path = match self.paths.get_mut(&path_key) {
|
|
||||
Some(path) => path,
|
|
||||
None => return Err(MerkleError::NodeNotInSet(index)),
|
|
||||
};
|
|
||||
|
|
||||
// Fill old_hashes vector -----------------------------------------------------------------
|
|
||||
let mut current_index = index;
|
|
||||
let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
|
||||
let mut root = Rpo256::merge(&[path[0], path[1]]);
|
|
||||
for hash in path.iter().skip(2).copied() {
|
|
||||
old_hashes.push(root);
|
|
||||
current_index.move_up();
|
|
||||
let input = current_index.build_node(hash, root);
|
|
||||
root = Rpo256::merge(&input);
|
|
||||
}
|
|
||||
|
|
||||
// Fill new_hashes vector -----------------------------------------------------------------
|
|
||||
path[index.is_value_odd() as usize] = value.into();
|
|
||||
|
|
||||
let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
|
||||
let mut new_root = Rpo256::merge(&[path[0], path[1]]);
|
|
||||
for path_hash in path.iter().skip(2).copied() {
|
|
||||
new_hashes.push(new_root);
|
|
||||
index.move_up();
|
|
||||
let input = current_index.build_node(path_hash, new_root);
|
|
||||
new_root = Rpo256::merge(&input);
|
|
||||
}
|
|
||||
|
|
||||
self.root = new_root;
|
|
||||
|
|
||||
// update paths ---------------------------------------------------------------------------
|
|
||||
for path in self.paths.values_mut() {
|
|
||||
for i in (0..old_hashes.len()).rev() {
|
|
||||
if path[i + 2] == old_hashes[i] {
|
|
||||
path[i + 2] = new_hashes[i];
|
|
||||
break;
|
|
||||
}
|
|
||||
}
|
|
||||
}
|
|
||||
|
|
||||
Ok(())
|
|
||||
}
|
|
||||
}
|
|
||||
|
|
||||
// TESTS
|
|
||||
// ================================================================================================
|
|
||||
|
|
||||
#[cfg(test)]
|
|
||||
mod tests {
|
|
||||
use super::*;
|
|
||||
use crate::merkle::{int_to_leaf, int_to_node};
|
|
||||
|
|
||||
#[test]
|
|
||||
fn get_root() {
|
|
||||
let leaf0 = int_to_node(0);
|
|
||||
let leaf1 = int_to_node(1);
|
|
||||
let leaf2 = int_to_node(2);
|
|
||||
let leaf3 = int_to_node(3);
|
|
||||
|
|
||||
let parent0 = calculate_parent_hash(leaf0, 0, leaf1);
|
|
||||
let parent1 = calculate_parent_hash(leaf2, 2, leaf3);
|
|
||||
|
|
||||
let root_exp = calculate_parent_hash(parent0, 0, parent1);
|
|
||||
|
|
||||
let set = super::MerklePathSet::new(2)
|
|
||||
.with_paths([(0, leaf0, vec![leaf1, parent1].into())])
|
|
||||
.unwrap();
|
|
||||
|
|
||||
assert_eq!(set.root(), root_exp);
|
|
||||
}
|
|
||||
|
|
||||
#[test]
|
|
||||
fn add_and_get_path() {
|
|
||||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
|
||||
let hash_6 = int_to_node(6);
|
|
||||
let index = 6_u64;
|
|
||||
let depth = 3_u8;
|
|
||||
let set = super::MerklePathSet::new(depth)
|
|
||||
.with_paths([(index, hash_6, path_6.clone().into())])
|
|
||||
.unwrap();
|
|
||||
let stored_path_6 = set.get_path(NodeIndex::make(depth, index)).unwrap();
|
|
||||
|
|
||||
assert_eq!(path_6, *stored_path_6);
|
|
||||
}
|
|
||||
|
|
||||
#[test]
|
|
||||
fn get_node() {
|
|
||||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
|
||||
let hash_6 = int_to_node(6);
|
|
||||
let index = 6_u64;
|
|
||||
let depth = 3_u8;
|
|
||||
let set = MerklePathSet::new(depth).with_paths([(index, hash_6, path_6.into())]).unwrap();
|
|
||||
|
|
||||
assert_eq!(int_to_node(6u64), set.get_node(NodeIndex::make(depth, index)).unwrap());
|
|
||||
}
|
|
||||
|
|
||||
#[test]
|
|
||||
fn update_leaf() {
|
|
||||
let hash_4 = int_to_node(4);
|
|
||||
let hash_5 = int_to_node(5);
|
|
||||
let hash_6 = int_to_node(6);
|
|
||||
let hash_7 = int_to_node(7);
|
|
||||
let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5);
|
|
||||
let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7);
|
|
||||
|
|
||||
let hash_0123 = int_to_node(123);
|
|
||||
|
|
||||
let path_6 = vec![hash_7, hash_45, hash_0123];
|
|
||||
let path_5 = vec![hash_4, hash_67, hash_0123];
|
|
||||
let path_4 = vec![hash_5, hash_67, hash_0123];
|
|
||||
|
|
||||
let index_6 = 6_u64;
|
|
||||
let index_5 = 5_u64;
|
|
||||
let index_4 = 4_u64;
|
|
||||
let depth = 3_u8;
|
|
||||
let mut set = MerklePathSet::new(depth)
|
|
||||
.with_paths([
|
|
||||
(index_6, hash_6, path_6.into()),
|
|
||||
(index_5, hash_5, path_5.into()),
|
|
||||
(index_4, hash_4, path_4.into()),
|
|
||||
])
|
|
||||
.unwrap();
|
|
||||
|
|
||||
let new_hash_6 = int_to_leaf(100);
|
|
||||
let new_hash_5 = int_to_leaf(55);
|
|
||||
|
|
||||
set.update_leaf(index_6, new_hash_6).unwrap();
|
|
||||
let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
|
|
||||
let new_hash_67 = calculate_parent_hash(new_hash_6.into(), 14_u64, hash_7);
|
|
||||
assert_eq!(new_hash_67, new_path_4[1]);
|
|
||||
|
|
||||
set.update_leaf(index_5, new_hash_5).unwrap();
|
|
||||
let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
|
|
||||
let new_path_6 = set.get_path(NodeIndex::make(depth, index_6)).unwrap();
|
|
||||
let new_hash_45 = calculate_parent_hash(new_hash_5.into(), 13_u64, hash_4);
|
|
||||
assert_eq!(new_hash_45, new_path_6[1]);
|
|
||||
assert_eq!(RpoDigest::from(new_hash_5), new_path_4[0]);
|
|
||||
}
|
|
||||
|
|
||||
#[test]
|
|
||||
fn depth_3_is_correct() {
|
|
||||
let a = int_to_node(1);
|
|
||||
let b = int_to_node(2);
|
|
||||
let c = int_to_node(3);
|
|
||||
let d = int_to_node(4);
|
|
||||
let e = int_to_node(5);
|
|
||||
let f = int_to_node(6);
|
|
||||
let g = int_to_node(7);
|
|
||||
let h = int_to_node(8);
|
|
||||
|
|
||||
let i = Rpo256::merge(&[a, b]);
|
|
||||
let j = Rpo256::merge(&[c, d]);
|
|
||||
let k = Rpo256::merge(&[e, f]);
|
|
||||
let l = Rpo256::merge(&[g, h]);
|
|
||||
|
|
||||
let m = Rpo256::merge(&[i, j]);
|
|
||||
let n = Rpo256::merge(&[k, l]);
|
|
||||
|
|
||||
let root = Rpo256::merge(&[m, n]);
|
|
||||
|
|
||||
let mut set = MerklePathSet::new(3);
|
|
||||
|
|
||||
let value = b;
|
|
||||
let index = 1;
|
|
||||
let path = MerklePath::new([a, j, n].to_vec());
|
|
||||
set.add_path(index, value.into(), path).unwrap();
|
|
||||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|
||||
assert_eq!(root, set.root());
|
|
||||
|
|
||||
let value = e;
|
|
||||
let index = 4;
|
|
||||
let path = MerklePath::new([f, l, m].to_vec());
|
|
||||
set.add_path(index, value.into(), path).unwrap();
|
|
||||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|
||||
assert_eq!(root, set.root());
|
|
||||
|
|
||||
let value = a;
|
|
||||
let index = 0;
|
|
||||
let path = MerklePath::new([b, j, n].to_vec());
|
|
||||
set.add_path(index, value.into(), path).unwrap();
|
|
||||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|
||||
assert_eq!(root, set.root());
|
|
||||
|
|
||||
let value = h;
|
|
||||
let index = 7;
|
|
||||
let path = MerklePath::new([g, k, m].to_vec());
|
|
||||
set.add_path(index, value.into(), path).unwrap();
|
|
||||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|
||||
assert_eq!(root, set.root());
|
|
||||
}
|
|
||||
|
|
||||
// HELPER FUNCTIONS
|
|
||||
// --------------------------------------------------------------------------------------------
|
|
||||
|
|
||||
const fn is_even(pos: u64) -> bool {
|
|
||||
pos & 1 == 0
|
|
||||
}
|
|
||||
|
|
||||
/// Calculates the hash of the parent node by two sibling ones
|
|
||||
/// - node — current node
|
|
||||
/// - node_pos — position of the current node
|
|
||||
/// - sibling — neighboring vertex in the tree
|
|
||||
fn calculate_parent_hash(node: RpoDigest, node_pos: u64, sibling: RpoDigest) -> RpoDigest {
|
|
||||
if is_even(node_pos) {
|
|
||||
Rpo256::merge(&[node, sibling])
|
|
||||
} else {
|
|
||||
Rpo256::merge(&[sibling, node])
|
|
||||
}
|
|
||||
}
|
|
||||
}
|
|
@ -0,0 +1,48 @@ |
|||||
|
use core::fmt::Display;
|
||||
|
|
||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||
|
pub enum TieredSmtProofError {
|
||||
|
EntriesEmpty,
|
||||
|
EmptyValueNotAllowed,
|
||||
|
MismatchedPrefixes(u64, u64),
|
||||
|
MultipleEntriesOutsideLastTier,
|
||||
|
NotATierPath(u8),
|
||||
|
PathTooLong,
|
||||
|
}
|
||||
|
|
||||
|
impl Display for TieredSmtProofError {
|
||||
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
|
match self {
|
||||
|
TieredSmtProofError::EntriesEmpty => {
|
||||
|
write!(f, "Missing entries for tiered sparse merkle tree proof")
|
||||
|
}
|
||||
|
TieredSmtProofError::EmptyValueNotAllowed => {
|
||||
|
write!(
|
||||
|
f,
|
||||
|
"The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
|
||||
|
)
|
||||
|
}
|
||||
|
TieredSmtProofError::MismatchedPrefixes(first, second) => {
|
||||
|
write!(f, "Not all leaves have the same prefix. First {first} second {second}")
|
||||
|
}
|
||||
|
TieredSmtProofError::MultipleEntriesOutsideLastTier => {
|
||||
|
write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
|
||||
|
}
|
||||
|
TieredSmtProofError::NotATierPath(got) => {
|
||||
|
write!(
|
||||
|
f,
|
||||
|
"Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
|
||||
|
)
|
||||
|
}
|
||||
|
TieredSmtProofError::PathTooLong => {
|
||||
|
write!(
|
||||
|
f,
|
||||
|
"Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
|
||||
|
)
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
#[cfg(feature = "std")]
|
||||
|
impl std::error::Error for TieredSmtProofError {}
|
@ -0,0 +1,419 @@ |
|||||
|
use super::{
|
||||
|
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
|
||||
|
NodeIndex, Rpo256, RpoDigest, Vec,
|
||||
|
};
|
||||
|
|
||||
|
// CONSTANTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// The number of levels between tiers.
|
||||
|
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;
|
||||
|
|
||||
|
/// Depths at which leaves can exist in a tiered SMT.
|
||||
|
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
|
||||
|
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
|
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
|
||||
|
// NODE STORE
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// A store of nodes for a Tiered Sparse Merkle tree.
|
||||
|
///
|
||||
|
/// The store contains information about all nodes as well as information about which of the nodes
|
||||
|
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
|
||||
|
/// are used to determine the position of the leaves in the tree.
|
||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub struct NodeStore {
|
||||
|
nodes: BTreeMap<NodeIndex, RpoDigest>,
|
||||
|
upper_leaves: BTreeSet<NodeIndex>,
|
||||
|
bottom_leaves: BTreeSet<u64>,
|
||||
|
}
|
||||
|
|
||||
|
impl NodeStore {
|
||||
|
// CONSTRUCTOR
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Returns a new instance of [NodeStore] instantiated with the specified root node.
|
||||
|
///
|
||||
|
/// Root node is assumed to be a root of an empty sparse Merkle tree.
|
||||
|
pub fn new(root_node: RpoDigest) -> Self {
|
||||
|
let mut nodes = BTreeMap::default();
|
||||
|
nodes.insert(NodeIndex::root(), root_node);
|
||||
|
|
||||
|
Self {
|
||||
|
nodes,
|
||||
|
upper_leaves: BTreeSet::default(),
|
||||
|
bottom_leaves: BTreeSet::default(),
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns a node at the specified index.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if:
|
||||
|
/// - The specified index depth is 0 or greater than 64.
|
||||
|
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
|
/// when a leaf node with the same index prefix exists at a tier higher than the requested
|
||||
|
/// node.
|
||||
|
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
|
self.validate_node_access(index)?;
|
||||
|
Ok(self.get_node_unchecked(&index))
|
||||
|
}
|
||||
|
|
||||
|
/// Returns a Merkle path from the node at the specified index to the root.
|
||||
|
///
|
||||
|
/// The node itself is not included in the path.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if:
|
||||
|
/// - The specified index depth is 0 or greater than 64.
|
||||
|
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
|
/// when a leaf node with the same index prefix exists at a tier higher than the node to
|
||||
|
/// which the path is requested.
|
||||
|
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
|
self.validate_node_access(index)?;
|
||||
|
|
||||
|
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
|
for _ in 0..index.depth() {
|
||||
|
let node = self.get_node_unchecked(&index.sibling());
|
||||
|
path.push(node);
|
||||
|
index.move_up();
|
||||
|
}
|
||||
|
|
||||
|
Ok(path.into())
|
||||
|
}
|
||||
|
|
||||
|
/// Returns a Merkle path to the node specified by the key together with a flag indicating,
|
||||
|
/// whether this node is a leaf at depths 16, 32, or 48.
|
||||
|
pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
|
||||
|
let (index, leaf_exists) = self.get_leaf_index(key);
|
||||
|
let index: NodeIndex = index.into();
|
||||
|
let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
|
||||
|
(path, index, leaf_exists)
|
||||
|
}
|
||||
|
|
||||
|
/// Returns an index at which a leaf node for the specified key should be inserted.
|
||||
|
///
|
||||
|
/// The second value in the returned tuple is set to true if the node at the returned index
|
||||
|
/// is already a leaf node.
|
||||
|
pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
|
||||
|
// traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
|
||||
|
// a node at any of the tiers is either a leaf or a root of an empty subtree.
|
||||
|
const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
|
||||
|
for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
|
||||
|
let index = LeafNodeIndex::from_key(key, tier_depth);
|
||||
|
if self.upper_leaves.contains(&index) {
|
||||
|
return (index, true);
|
||||
|
} else if !self.nodes.contains_key(&index) {
|
||||
|
return (index, false);
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// if we got here, that means all of the nodes checked so far are internal nodes, and
|
||||
|
// the new node would need to be inserted in the bottom tier.
|
||||
|
let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
|
||||
|
(index, self.bottom_leaves.contains(&index.value()))
|
||||
|
}
|
||||
|
|
||||
|
/// Traverses the tree up from the bottom tier starting at the specified leaf index and
|
||||
|
/// returns the depth of the first node which hash more than one child. The returned depth
|
||||
|
/// is rounded up to the next tier.
|
||||
|
pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
|
||||
|
let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);
|
||||
|
|
||||
|
for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
|
||||
|
let sibling_index = index.sibling();
|
||||
|
if self.nodes.contains_key(&sibling_index) {
|
||||
|
break;
|
||||
|
}
|
||||
|
index.move_up();
|
||||
|
}
|
||||
|
|
||||
|
let tier = (index.depth() - 1) / TIER_SIZE;
|
||||
|
TIER_DEPTHS[tier as usize]
|
||||
|
}
|
||||
|
|
||||
|
// ITERATORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
|
||||
|
/// at depths 16 32, 48, or 64).
|
||||
|
///
|
||||
|
/// The iterator order is unspecified.
|
||||
|
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
|
self.nodes.iter().filter_map(|(index, node)| {
|
||||
|
if self.is_internal_node(index) {
|
||||
|
Some(InnerNodeInfo {
|
||||
|
value: *node,
|
||||
|
left: self.get_node_unchecked(&index.left_child()),
|
||||
|
right: self.get_node_unchecked(&index.right_child()),
|
||||
|
})
|
||||
|
} else {
|
||||
|
None
|
||||
|
}
|
||||
|
})
|
||||
|
}
|
||||
|
|
||||
|
/// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
|
||||
|
/// Tiered Sparse Merkle tree.
|
||||
|
pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
|
||||
|
self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
|
||||
|
}
|
||||
|
|
||||
|
/// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
|
||||
|
/// Sparse Merkle tree.
|
||||
|
pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
|
||||
|
self.bottom_leaves.iter().map(|value| {
|
||||
|
let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
|
||||
|
(value, &self.nodes[&index])
|
||||
|
})
|
||||
|
}
|
||||
|
|
||||
|
// STATE MUTATORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Replaces the leaf node at the specified index with a tree consisting of two leaves located
|
||||
|
/// at the specified indexes. Recomputes and returns the new root.
|
||||
|
pub fn replace_leaf_with_subtree(
|
||||
|
&mut self,
|
||||
|
leaf_index: LeafNodeIndex,
|
||||
|
subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
|
||||
|
) -> RpoDigest {
|
||||
|
debug_assert!(self.is_non_empty_leaf(&leaf_index));
|
||||
|
debug_assert!(!is_empty_root(&subtree_leaves[0].1));
|
||||
|
debug_assert!(!is_empty_root(&subtree_leaves[1].1));
|
||||
|
debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
|
||||
|
debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());
|
||||
|
|
||||
|
self.upper_leaves.remove(&leaf_index);
|
||||
|
|
||||
|
if subtree_leaves[0].0 == subtree_leaves[1].0 {
|
||||
|
// if the subtree is for a single node at depth 64, we only need to insert one node
|
||||
|
debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
|
||||
|
debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
|
||||
|
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
|
||||
|
} else {
|
||||
|
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
|
||||
|
self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
|
||||
|
/// containing the retained leaf.
|
||||
|
///
|
||||
|
/// This has the effect of deleting the the node at the `removed_leaf` index from the tree,
|
||||
|
/// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
|
||||
|
pub fn replace_subtree_with_leaf(
|
||||
|
&mut self,
|
||||
|
removed_leaf: LeafNodeIndex,
|
||||
|
retained_leaf: LeafNodeIndex,
|
||||
|
new_depth: u8,
|
||||
|
node: RpoDigest,
|
||||
|
) -> RpoDigest {
|
||||
|
debug_assert!(!is_empty_root(&node));
|
||||
|
debug_assert!(self.is_non_empty_leaf(&removed_leaf));
|
||||
|
debug_assert!(self.is_non_empty_leaf(&retained_leaf));
|
||||
|
debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
|
||||
|
debug_assert!(removed_leaf.depth() > new_depth);
|
||||
|
|
||||
|
// remove the branches leading up to the tier to which the retained leaf is to be moved
|
||||
|
self.remove_branch(removed_leaf, new_depth);
|
||||
|
self.remove_branch(retained_leaf, new_depth);
|
||||
|
|
||||
|
// compute the index of the common root for retained and removed leaves
|
||||
|
let mut new_index = retained_leaf;
|
||||
|
new_index.move_up_to(new_depth);
|
||||
|
|
||||
|
// insert the node at the root index
|
||||
|
self.insert_leaf_node(new_index, node)
|
||||
|
}
|
||||
|
|
||||
|
/// Inserts the specified node at the specified index; recomputes and returns the new root
|
||||
|
/// of the Tiered Sparse Merkle tree.
|
||||
|
///
|
||||
|
/// This method assumes that the provided node is a non-empty value, and that there is no node
|
||||
|
/// at the specified index.
|
||||
|
pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
||||
|
debug_assert!(!is_empty_root(&node));
|
||||
|
debug_assert_eq!(self.nodes.get(&index), None);
|
||||
|
|
||||
|
// mark the node as the leaf
|
||||
|
if index.depth() == MAX_DEPTH {
|
||||
|
self.bottom_leaves.insert(index.value());
|
||||
|
} else {
|
||||
|
self.upper_leaves.insert(index.into());
|
||||
|
};
|
||||
|
|
||||
|
// insert the node and update the path from the node to the root
|
||||
|
let mut index: NodeIndex = index.into();
|
||||
|
for _ in 0..index.depth() {
|
||||
|
self.nodes.insert(index, node);
|
||||
|
let sibling = self.get_node_unchecked(&index.sibling());
|
||||
|
node = Rpo256::merge(&index.build_node(node, sibling));
|
||||
|
index.move_up();
|
||||
|
}
|
||||
|
|
||||
|
// update the root
|
||||
|
self.nodes.insert(NodeIndex::root(), node);
|
||||
|
node
|
||||
|
}
|
||||
|
|
||||
|
/// Updates the node at the specified index with the specified node value; recomputes and
|
||||
|
/// returns the new root of the Tiered Sparse Merkle tree.
|
||||
|
///
|
||||
|
/// This method can accept `node` as either an empty or a non-empty value.
|
||||
|
pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
||||
|
debug_assert!(self.is_non_empty_leaf(&index));
|
||||
|
|
||||
|
// if the value we are updating the node to is a root of an empty tree, clear the leaf
|
||||
|
// flag for this node
|
||||
|
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
||||
|
if index.depth() == MAX_DEPTH {
|
||||
|
self.bottom_leaves.remove(&index.value());
|
||||
|
} else {
|
||||
|
self.upper_leaves.remove(&index);
|
||||
|
}
|
||||
|
} else {
|
||||
|
debug_assert!(!is_empty_root(&node));
|
||||
|
}
|
||||
|
|
||||
|
// update the path from the node to the root
|
||||
|
let mut index: NodeIndex = index.into();
|
||||
|
for _ in 0..index.depth() {
|
||||
|
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
||||
|
self.nodes.remove(&index);
|
||||
|
} else {
|
||||
|
self.nodes.insert(index, node);
|
||||
|
}
|
||||
|
|
||||
|
let sibling = self.get_node_unchecked(&index.sibling());
|
||||
|
node = Rpo256::merge(&index.build_node(node, sibling));
|
||||
|
index.move_up();
|
||||
|
}
|
||||
|
|
||||
|
// update the root
|
||||
|
self.nodes.insert(NodeIndex::root(), node);
|
||||
|
node
|
||||
|
}
|
||||
|
|
||||
|
/// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
|
||||
|
/// and returns the new root of the Tiered Sparse Merkle tree.
|
||||
|
pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
|
||||
|
debug_assert!(self.is_non_empty_leaf(&index));
|
||||
|
let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
|
||||
|
self.update_leaf_node(index, node)
|
||||
|
}
|
||||
|
|
||||
|
/// Truncates a branch starting with specified leaf at the bottom tier to new depth.
|
||||
|
///
|
||||
|
/// This involves removing the part of the branch below the new depth, and then inserting a new
|
||||
|
/// // node at the new depth.
|
||||
|
pub fn truncate_branch(
|
||||
|
&mut self,
|
||||
|
leaf_index: u64,
|
||||
|
new_depth: u8,
|
||||
|
node: RpoDigest,
|
||||
|
) -> RpoDigest {
|
||||
|
debug_assert!(self.bottom_leaves.contains(&leaf_index));
|
||||
|
|
||||
|
let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
|
||||
|
self.remove_branch(leaf_index, new_depth);
|
||||
|
|
||||
|
leaf_index.move_up_to(new_depth);
|
||||
|
self.insert_leaf_node(leaf_index, node)
|
||||
|
}
|
||||
|
|
||||
|
// HELPER METHODS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns true if the node at the specified index is a leaf node.
|
||||
|
fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
|
||||
|
if index.depth() == MAX_DEPTH {
|
||||
|
self.bottom_leaves.contains(&index.value())
|
||||
|
} else {
|
||||
|
self.upper_leaves.contains(index)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Returns true if the node at the specified index is an internal node - i.e., there is
|
||||
|
/// no leaf at that node and the node does not belong to the bottom tier.
|
||||
|
fn is_internal_node(&self, index: &NodeIndex) -> bool {
|
||||
|
if index.depth() == MAX_DEPTH {
|
||||
|
false
|
||||
|
} else {
|
||||
|
!self.upper_leaves.contains(index)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Checks if the specified index is valid in the context of this Merkle tree.
|
||||
|
///
|
||||
|
/// # Errors
|
||||
|
/// Returns an error if:
|
||||
|
/// - The specified index depth is 0 or greater than 64.
|
||||
|
/// - The node for the specified index does not exists in the Merkle tree. This is possible
|
||||
|
/// when an ancestors of the specified index is a leaf node.
|
||||
|
fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
|
||||
|
if index.is_root() {
|
||||
|
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
|
} else if index.depth() > MAX_DEPTH {
|
||||
|
return Err(MerkleError::DepthTooBig(index.depth() as u64));
|
||||
|
} else {
|
||||
|
// make sure that there are no leaf nodes in the ancestors of the index; since leaf
|
||||
|
// nodes can live at specific depth, we just need to check these depths.
|
||||
|
let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
|
||||
|
let mut tier_index = index;
|
||||
|
for &depth in TIER_DEPTHS[..tier].iter().rev() {
|
||||
|
tier_index.move_up_to(depth);
|
||||
|
if self.upper_leaves.contains(&tier_index) {
|
||||
|
return Err(MerkleError::NodeNotInSet(index));
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
Ok(())
|
||||
|
}
|
||||
|
|
||||
|
/// Returns a node at the specified index. If the node does not exist at this index, a root
|
||||
|
/// for an empty subtree at the index's depth is returned.
|
||||
|
///
|
||||
|
/// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
|
||||
|
/// returned node is valid in the context of this tree.
|
||||
|
fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
|
||||
|
match self.nodes.get(index) {
|
||||
|
Some(node) => *node,
|
||||
|
None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Removes a sequence of nodes starting at the specified index and traversing the tree up to
|
||||
|
/// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
|
||||
|
/// flag is cleared.
|
||||
|
///
|
||||
|
/// This method does not update any other nodes and does not recompute the tree root.
|
||||
|
fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
|
||||
|
if index.depth() == MAX_DEPTH {
|
||||
|
self.bottom_leaves.remove(&index.value());
|
||||
|
} else {
|
||||
|
self.upper_leaves.remove(&index);
|
||||
|
}
|
||||
|
|
||||
|
let mut index: NodeIndex = index.into();
|
||||
|
assert!(index.depth() > end_depth);
|
||||
|
for _ in 0..(index.depth() - end_depth + 1) {
|
||||
|
self.nodes.remove(&index);
|
||||
|
index.move_up()
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// HELPER FUNCTIONS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
|
||||
|
fn is_empty_root(node: &RpoDigest) -> bool {
|
||||
|
EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node)
|
||||
|
}
|
@ -0,0 +1,162 @@ |
|||||
|
use super::{
|
||||
|
get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
|
||||
|
EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
|
||||
|
};
|
||||
|
|
||||
|
// CONSTANTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
|
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
|
||||
|
/// Value of an empty leaf.
|
||||
|
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;
|
||||
|
|
||||
|
/// Depths at which leaves can exist in a tiered SMT.
|
||||
|
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
|
||||
|
// TIERED SPARSE MERKLE TREE PROOF
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
|
||||
|
/// Tiered Sparse Merkle tree.
|
||||
|
///
|
||||
|
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
|
||||
|
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
|
||||
|
/// the entries will contain a single item with value set to [ZERO; 4].
|
||||
|
#[derive(PartialEq, Eq, Debug)]
|
||||
|
pub struct TieredSmtProof {
|
||||
|
path: MerklePath,
|
||||
|
entries: Vec<(RpoDigest, Word)>,
|
||||
|
}
|
||||
|
|
||||
|
impl TieredSmtProof {
|
||||
|
// CONSTRUCTOR
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries.
|
||||
|
///
|
||||
|
/// # Panics
|
||||
|
/// Panics if:
|
||||
|
/// - The length of the path is greater than 64.
|
||||
|
/// - Entries is an empty vector.
|
||||
|
/// - Entries contains more than 1 item, but the length of the path is not 64.
|
||||
|
/// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4].
|
||||
|
/// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
|
||||
|
pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
|
||||
|
where
|
||||
|
I: IntoIterator<Item = (RpoDigest, Word)>,
|
||||
|
{
|
||||
|
let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
|
||||
|
|
||||
|
if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
|
||||
|
return Err(TieredSmtProofError::NotATierPath(path.depth()));
|
||||
|
}
|
||||
|
|
||||
|
if entries.is_empty() {
|
||||
|
return Err(TieredSmtProofError::EntriesEmpty);
|
||||
|
}
|
||||
|
|
||||
|
if entries.len() > 1 {
|
||||
|
if path.depth() != MAX_DEPTH {
|
||||
|
return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
|
||||
|
}
|
||||
|
|
||||
|
let prefix = get_key_prefix(&entries[0].0);
|
||||
|
for entry in entries.iter().skip(1) {
|
||||
|
if entry.1 == EMPTY_VALUE {
|
||||
|
return Err(TieredSmtProofError::EmptyValueNotAllowed);
|
||||
|
}
|
||||
|
let current = get_key_prefix(&entry.0);
|
||||
|
if prefix != current {
|
||||
|
return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
Ok(Self { path, entries })
|
||||
|
}
|
||||
|
|
||||
|
// PROOF VERIFIER
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
|
||||
|
/// key-value pair.
|
||||
|
///
|
||||
|
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
|
||||
|
/// it does not mean that the provided key-value pair is not in the tree.
|
||||
|
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
|
||||
|
if self.is_value_empty() {
|
||||
|
if value != &EMPTY_VALUE {
|
||||
|
return false;
|
||||
|
}
|
||||
|
// if the proof is for an empty value, we can verify it against any key which has a
|
||||
|
// common prefix with the key storied in entries, but the prefix must be greater than
|
||||
|
// the path length
|
||||
|
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
||||
|
if common_prefix_tier < self.path.depth() {
|
||||
|
return false;
|
||||
|
}
|
||||
|
} else if !self.entries.contains(&(*key, *value)) {
|
||||
|
return false;
|
||||
|
}
|
||||
|
|
||||
|
// make sure the Merkle path resolves to the correct root
|
||||
|
root == &self.compute_root()
|
||||
|
}
|
||||
|
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns the value associated with the specific key according to this proof, or None if
|
||||
|
/// this proof does not contain a value for the specified key.
|
||||
|
///
|
||||
|
/// A key-value pair generated by using this method should pass the `verify_membership()` check.
|
||||
|
pub fn get(&self, key: &RpoDigest) -> Option<Word> {
|
||||
|
if self.is_value_empty() {
|
||||
|
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
||||
|
if common_prefix_tier < self.path.depth() {
|
||||
|
None
|
||||
|
} else {
|
||||
|
Some(EMPTY_VALUE)
|
||||
|
}
|
||||
|
} else {
|
||||
|
self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Computes the root of a Tiered Sparse Merkle tree to which this proof resolve.
|
||||
|
pub fn compute_root(&self) -> RpoDigest {
|
||||
|
let node = self.build_node();
|
||||
|
let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
|
||||
|
self.path
|
||||
|
.compute_root(index.value(), node)
|
||||
|
.expect("failed to compute Merkle path root")
|
||||
|
}
|
||||
|
|
||||
|
/// Consume the proof and returns its parts.
|
||||
|
pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
|
||||
|
(self.path, self.entries)
|
||||
|
}
|
||||
|
|
||||
|
// HELPER METHODS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns true if the proof is for an empty value.
|
||||
|
fn is_value_empty(&self) -> bool {
|
||||
|
self.entries[0].1 == EMPTY_VALUE
|
||||
|
}
|
||||
|
|
||||
|
/// Converts the entries contained in this proof into a node value for node at the base of the
|
||||
|
/// path contained in this proof.
|
||||
|
fn build_node(&self) -> RpoDigest {
|
||||
|
let depth = self.path.depth();
|
||||
|
if self.is_value_empty() {
|
||||
|
EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
|
||||
|
} else if depth == MAX_DEPTH {
|
||||
|
hash_bottom_leaf(&self.entries)
|
||||
|
} else {
|
||||
|
let (key, value) = self.entries[0];
|
||||
|
hash_upper_leaf(key, value, depth)
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,584 @@ |
|||||
|
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
|
||||
|
use crate::utils::vec;
|
||||
|
use core::{
|
||||
|
cmp::{Ord, Ordering},
|
||||
|
ops::RangeBounds,
|
||||
|
};
|
||||
|
use winter_utils::collections::btree_map::Entry;
|
||||
|
|
||||
|
// CONSTANTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// Depths at which leaves can exist in a tiered SMT.
|
||||
|
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
|
||||
|
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
|
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
|
||||
|
// VALUE STORE
|
||||
|
// ================================================================================================
|
||||
|
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
|
||||
|
///
|
||||
|
/// The store is organized in a [BTreeMap] where keys are 64 most significant bits of a key, and
|
||||
|
/// the values are the corresponding key-value pairs (or a list of key-value pairs if more that
|
||||
|
/// a single key-value pair shares the same 64-bit prefix).
|
||||
|
///
|
||||
|
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
|
||||
|
/// prefix.
|
||||
|
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub struct ValueStore {
|
||||
|
values: BTreeMap<u64, StoreEntry>,
|
||||
|
}
|
||||
|
|
||||
|
impl ValueStore {
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns a reference to the value stored under the specified key, or None if there is no
|
||||
|
/// value associated with the specified key.
|
||||
|
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
||||
|
let prefix = get_key_prefix(key);
|
||||
|
self.values.get(&prefix).and_then(|entry| entry.get(key))
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
||||
|
/// specified prefix.
|
||||
|
pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
|
||||
|
self.range(prefix..).next()
|
||||
|
}
|
||||
|
|
||||
|
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
||||
|
/// specified prefix and the key value is not equal to the exclude_key value.
|
||||
|
pub fn get_first_filtered(
|
||||
|
&self,
|
||||
|
prefix: u64,
|
||||
|
exclude_key: &RpoDigest,
|
||||
|
) -> Option<&(RpoDigest, Word)> {
|
||||
|
self.range(prefix..).find(|(key, _)| key != exclude_key)
|
||||
|
}
|
||||
|
|
||||
|
/// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
|
||||
|
/// None if no keys with the specified prefix are present in this store.
|
||||
|
pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
|
||||
|
self.values.get(&prefix).map(|entry| match entry {
|
||||
|
StoreEntry::Single(kv_pair) => vec![*kv_pair],
|
||||
|
StoreEntry::List(kv_pairs) => kv_pairs.clone(),
|
||||
|
})
|
||||
|
}
|
||||
|
|
||||
|
/// Returns information about a sibling of a leaf node with the specified index, but only if
|
||||
|
/// this is the only sibling the leaf has in some subtree starting at the first tier.
|
||||
|
///
|
||||
|
/// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with
|
||||
|
/// the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
|
||||
|
///
|
||||
|
/// The returned tuple contains: they key-value pair of the sibling as well as the index of
|
||||
|
/// the node for the root of the common subtree in which both nodes are leaves.
|
||||
|
///
|
||||
|
/// This method assumes that the key-value pair for the specified index has already been
|
||||
|
/// removed from the store.
|
||||
|
pub fn get_lone_sibling(
|
||||
|
&self,
|
||||
|
index: LeafNodeIndex,
|
||||
|
) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
|
||||
|
// iterate over tiers from top to bottom, looking at the tiers which are strictly above
|
||||
|
// the depth of the index. This implies that only tiers at depth 32 and 48 will be
|
||||
|
// considered. For each tier, check if the parent of the index at the higher tier
|
||||
|
// contains a single node. The fist tier (depth 16) is excluded because we cannot move
|
||||
|
// nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
|
||||
|
// never have "lone siblings".
|
||||
|
for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
|
||||
|
// compute the index of the root at a higher tier
|
||||
|
let mut parent_index = index;
|
||||
|
parent_index.move_up_to(tier_depth);
|
||||
|
|
||||
|
// find the lone sibling, if any; we need to handle the "last node" at a given tier
|
||||
|
// separately specify the bounds for the search correctly.
|
||||
|
let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
|
||||
|
let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
|
||||
|
let mut iter = self.range(start_prefix..);
|
||||
|
iter.next().filter(|_| iter.next().is_none())
|
||||
|
} else {
|
||||
|
let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
|
||||
|
let mut iter = self.range(start_prefix..end_prefix);
|
||||
|
iter.next().filter(|_| iter.next().is_none())
|
||||
|
};
|
||||
|
|
||||
|
if let Some((key, value)) = sibling {
|
||||
|
return Some((key, value, parent_index));
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
None
|
||||
|
}
|
||||
|
|
||||
|
/// Returns an iterator over all key-value pairs in this store.
|
||||
|
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
|
self.values.iter().flat_map(|(_, entry)| entry.iter())
|
||||
|
}
|
||||
|
|
||||
|
// STATE MUTATORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Inserts the specified key-value pair into this store and returns the value previously
|
||||
|
/// associated with the specified key.
|
||||
|
///
|
||||
|
/// If no value was previously associated with the specified key, None is returned.
|
||||
|
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
|
let prefix = get_key_prefix(&key);
|
||||
|
match self.values.entry(prefix) {
|
||||
|
Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
|
||||
|
Entry::Vacant(entry) => {
|
||||
|
entry.insert(StoreEntry::new(key, value));
|
||||
|
None
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Removes the key-value pair for the specified key from this store and returns the value
|
||||
|
/// associated with this key.
|
||||
|
///
|
||||
|
/// If no value was associated with the specified key, None is returned.
|
||||
|
pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
|
||||
|
let prefix = get_key_prefix(key);
|
||||
|
match self.values.entry(prefix) {
|
||||
|
Entry::Occupied(mut entry) => {
|
||||
|
let (value, remove_entry) = entry.get_mut().remove(key);
|
||||
|
if remove_entry {
|
||||
|
entry.remove_entry();
|
||||
|
}
|
||||
|
value
|
||||
|
}
|
||||
|
Entry::Vacant(_) => None,
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// HELPER METHODS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns an iterator over all key-value pairs contained in this store such that the most
|
||||
|
/// significant 64 bits of the key lay within the specified bounds.
|
||||
|
///
|
||||
|
/// The order of iteration is from the smallest to the largest key.
|
||||
|
fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
|
self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// VALUE NODE
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// An entry in the [ValueStore].
|
||||
|
///
|
||||
|
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
|
||||
|
/// key.
|
||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
|
pub enum StoreEntry {
|
||||
|
Single((RpoDigest, Word)),
|
||||
|
List(Vec<(RpoDigest, Word)>),
|
||||
|
}
|
||||
|
|
||||
|
impl StoreEntry {
|
||||
|
// CONSTRUCTOR
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
/// Returns a new [StoreEntry] instantiated with a single key-value pair.
|
||||
|
pub fn new(key: RpoDigest, value: Word) -> Self {
|
||||
|
Self::Single((key, value))
|
||||
|
}
|
||||
|
|
||||
|
// PUBLIC ACCESSORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Returns the value associated with the specified key, or None if this entry does not contain
|
||||
|
/// a value associated with the specified key.
|
||||
|
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
||||
|
match self {
|
||||
|
StoreEntry::Single(kv_pair) => {
|
||||
|
if kv_pair.0 == *key {
|
||||
|
Some(&kv_pair.1)
|
||||
|
} else {
|
||||
|
None
|
||||
|
}
|
||||
|
}
|
||||
|
StoreEntry::List(kv_pairs) => {
|
||||
|
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
||||
|
Ok(pos) => Some(&kv_pairs[pos].1),
|
||||
|
Err(_) => None,
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Returns an iterator over all key-value pairs in this entry.
|
||||
|
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
|
EntryIterator { entry: self, pos: 0 }
|
||||
|
}
|
||||
|
|
||||
|
// STATE MUTATORS
|
||||
|
// --------------------------------------------------------------------------------------------
|
||||
|
|
||||
|
/// Inserts the specified key-value pair into this entry and returns the value previously
|
||||
|
/// associated with the specified key, or None if no value was associated with the specified
|
||||
|
/// key.
|
||||
|
///
|
||||
|
/// If a new key is inserted, this will also transform a `SingleEntry` into a `ListEntry`.
|
||||
|
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
|
match self {
|
||||
|
StoreEntry::Single(kv_pair) => {
|
||||
|
// if the key is already in this entry, update the value and return
|
||||
|
if kv_pair.0 == key {
|
||||
|
let old_value = kv_pair.1;
|
||||
|
kv_pair.1 = value;
|
||||
|
return Some(old_value);
|
||||
|
}
|
||||
|
|
||||
|
// transform the entry into a list entry, and make sure the key-value pairs
|
||||
|
// are sorted by key
|
||||
|
let mut pairs = vec![*kv_pair, (key, value)];
|
||||
|
pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));
|
||||
|
|
||||
|
*self = StoreEntry::List(pairs);
|
||||
|
None
|
||||
|
}
|
||||
|
StoreEntry::List(pairs) => {
|
||||
|
match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
|
||||
|
Ok(pos) => {
|
||||
|
let old_value = pairs[pos].1;
|
||||
|
pairs[pos].1 = value;
|
||||
|
Some(old_value)
|
||||
|
}
|
||||
|
Err(pos) => {
|
||||
|
pairs.insert(pos, (key, value));
|
||||
|
None
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// Removes the key-value pair with the specified key from this entry, and returns the value
|
||||
|
/// of the removed pair. If the entry did not contain a key-value pair for the specified key,
|
||||
|
/// None is returned.
|
||||
|
///
|
||||
|
/// If the last last key-value pair was removed from the entry, the second tuple value will
|
||||
|
/// be set to true.
|
||||
|
pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
|
||||
|
match self {
|
||||
|
StoreEntry::Single(kv_pair) => {
|
||||
|
if kv_pair.0 == *key {
|
||||
|
(Some(kv_pair.1), true)
|
||||
|
} else {
|
||||
|
(None, false)
|
||||
|
}
|
||||
|
}
|
||||
|
StoreEntry::List(kv_pairs) => {
|
||||
|
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
||||
|
Ok(pos) => {
|
||||
|
let kv_pair = kv_pairs.remove(pos);
|
||||
|
if kv_pairs.len() == 1 {
|
||||
|
*self = StoreEntry::Single(kv_pairs[0]);
|
||||
|
}
|
||||
|
(Some(kv_pair.1), false)
|
||||
|
}
|
||||
|
Err(_) => (None, false),
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
/// A custom iterator over key-value pairs of a [StoreEntry].
|
||||
|
///
|
||||
|
/// For a `SingleEntry` this returns only one value, but for `ListEntry`, this iterates over the
|
||||
|
/// entire list of key-value pairs.
|
||||
|
pub struct EntryIterator<'a> {
|
||||
|
entry: &'a StoreEntry,
|
||||
|
pos: usize,
|
||||
|
}
|
||||
|
|
||||
|
impl<'a> Iterator for EntryIterator<'a> {
|
||||
|
type Item = &'a (RpoDigest, Word);
|
||||
|
|
||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||
|
match self.entry {
|
||||
|
StoreEntry::Single(kv_pair) => {
|
||||
|
if self.pos == 0 {
|
||||
|
self.pos = 1;
|
||||
|
Some(kv_pair)
|
||||
|
} else {
|
||||
|
None
|
||||
|
}
|
||||
|
}
|
||||
|
StoreEntry::List(kv_pairs) => {
|
||||
|
if self.pos >= kv_pairs.len() {
|
||||
|
None
|
||||
|
} else {
|
||||
|
let kv_pair = &kv_pairs[self.pos];
|
||||
|
self.pos += 1;
|
||||
|
Some(kv_pair)
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
// HELPER FUNCTIONS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
/// Compares two digests element-by-element using their integer representations starting with the
|
||||
|
/// most significant element.
|
||||
|
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
|
||||
|
let d1 = Word::from(d1);
|
||||
|
let d2 = Word::from(d2);
|
||||
|
|
||||
|
for (v1, v2) in d1.iter().zip(d2.iter()).rev() {
|
||||
|
let v1 = v1.as_int();
|
||||
|
let v2 = v2.as_int();
|
||||
|
if v1 != v2 {
|
||||
|
return v1.cmp(&v2);
|
||||
|
}
|
||||
|
}
|
||||
|
|
||||
|
Ordering::Equal
|
||||
|
}
|
||||
|
|
||||
|
// TESTS
|
||||
|
// ================================================================================================
|
||||
|
|
||||
|
#[cfg(test)]
|
||||
|
mod tests {
|
||||
|
use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
|
||||
|
use crate::{Felt, ONE, WORD_SIZE, ZERO};
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_insert() {
|
||||
|
let mut store = ValueStore::default();
|
||||
|
|
||||
|
// insert the first key-value pair into the store
|
||||
|
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_a = [ONE; WORD_SIZE];
|
||||
|
|
||||
|
assert!(store.insert(key_a, value_a).is_none());
|
||||
|
assert_eq!(store.values.len(), 1);
|
||||
|
|
||||
|
let entry = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry = StoreEntry::Single((key_a, value_a));
|
||||
|
assert_eq!(entry, &expected_entry);
|
||||
|
|
||||
|
// insert a key-value pair with a different key into the store; since the keys are
|
||||
|
// different, another entry is added to the values map
|
||||
|
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
|
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
|
|
||||
|
assert!(store.insert(key_b, value_b).is_none());
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 = StoreEntry::Single((key_a, value_a));
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
|
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
|
assert_eq!(entry2, &expected_entry2);
|
||||
|
|
||||
|
// insert a key-value pair with the same 64-bit key prefix as the first key; this should
|
||||
|
// transform the first entry into a List entry
|
||||
|
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
|
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
|
|
||||
|
assert!(store.insert(key_c, value_c).is_none());
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
|
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
|
assert_eq!(entry2, &expected_entry2);
|
||||
|
|
||||
|
// replace values for keys a and b
|
||||
|
let value_a2 = [ONE, ONE, ONE, ZERO];
|
||||
|
let value_b2 = [ZERO, ZERO, ZERO, ONE];
|
||||
|
|
||||
|
assert_eq!(store.insert(key_a, value_a2), Some(value_a));
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
assert_eq!(store.insert(key_b, value_b2), Some(value_b));
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
|
let expected_entry2 = StoreEntry::Single((key_b, value_b2));
|
||||
|
assert_eq!(entry2, &expected_entry2);
|
||||
|
|
||||
|
// insert one more key-value pair with the same 64-bit key-prefix as the first key
|
||||
|
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
|
|
||||
|
assert!(store.insert(key_d, value_d).is_none());
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 =
|
||||
|
StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
|
let expected_entry2 = StoreEntry::Single((key_b, value_b2));
|
||||
|
assert_eq!(entry2, &expected_entry2);
|
||||
|
}
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_remove() {
|
||||
|
// populate the value store
|
||||
|
let mut store = ValueStore::default();
|
||||
|
|
||||
|
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_a = [ONE; WORD_SIZE];
|
||||
|
store.insert(key_a, value_a);
|
||||
|
|
||||
|
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
|
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
|
store.insert(key_b, value_b);
|
||||
|
|
||||
|
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
|
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
|
store.insert(key_c, value_c);
|
||||
|
|
||||
|
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
|
store.insert(key_d, value_d);
|
||||
|
|
||||
|
assert_eq!(store.values.len(), 2);
|
||||
|
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 =
|
||||
|
StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
|
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
|
assert_eq!(entry2, &expected_entry2);
|
||||
|
|
||||
|
// remove non-existent keys
|
||||
|
let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
|
||||
|
assert!(store.remove(&key_e).is_none());
|
||||
|
|
||||
|
let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
|
||||
|
assert!(store.remove(&key_f).is_none());
|
||||
|
|
||||
|
// remove keys from the list entry
|
||||
|
assert_eq!(store.remove(&key_c).unwrap(), value_c);
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
assert_eq!(store.remove(&key_a).unwrap(), value_a);
|
||||
|
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
|
let expected_entry1 = StoreEntry::Single((key_d, value_d));
|
||||
|
assert_eq!(entry1, &expected_entry1);
|
||||
|
|
||||
|
assert_eq!(store.remove(&key_d).unwrap(), value_d);
|
||||
|
assert!(store.values.get(&raw_a).is_none());
|
||||
|
assert_eq!(store.values.len(), 1);
|
||||
|
|
||||
|
// remove a key from a single entry
|
||||
|
assert_eq!(store.remove(&key_b).unwrap(), value_b);
|
||||
|
assert!(store.values.get(&raw_b).is_none());
|
||||
|
assert_eq!(store.values.len(), 0);
|
||||
|
}
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_range() {
|
||||
|
// populate the value store
|
||||
|
let mut store = ValueStore::default();
|
||||
|
|
||||
|
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_a = [ONE; WORD_SIZE];
|
||||
|
store.insert(key_a, value_a);
|
||||
|
|
||||
|
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
|
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
|
store.insert(key_b, value_b);
|
||||
|
|
||||
|
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
|
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
|
store.insert(key_c, value_c);
|
||||
|
|
||||
|
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
|
store.insert(key_d, value_d);
|
||||
|
|
||||
|
let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
|
||||
|
let value_e = [ZERO, ZERO, ZERO, ONE];
|
||||
|
store.insert(key_e, value_e);
|
||||
|
|
||||
|
// check the entire range
|
||||
|
let mut iter = store.range(..u64::MAX);
|
||||
|
assert_eq!(iter.next(), Some(&(key_e, value_e)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_b, value_b)));
|
||||
|
assert_eq!(iter.next(), None);
|
||||
|
|
||||
|
// check all but e
|
||||
|
let mut iter = store.range(raw_a..u64::MAX);
|
||||
|
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_b, value_b)));
|
||||
|
assert_eq!(iter.next(), None);
|
||||
|
|
||||
|
// check all but e and b
|
||||
|
let mut iter = store.range(raw_a..raw_b);
|
||||
|
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
|
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
|
assert_eq!(iter.next(), None);
|
||||
|
}
|
||||
|
|
||||
|
#[test]
|
||||
|
fn test_get_lone_sibling() {
|
||||
|
// populate the value store
|
||||
|
let mut store = ValueStore::default();
|
||||
|
|
||||
|
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
|
let value_a = [ONE; WORD_SIZE];
|
||||
|
store.insert(key_a, value_a);
|
||||
|
|
||||
|
let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
|
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
|
store.insert(key_b, value_b);
|
||||
|
|
||||
|
// check sibling node for `a`
|
||||
|
let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
|
||||
|
let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
|
||||
|
assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));
|
||||
|
|
||||
|
// check sibling node for `b`
|
||||
|
let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
|
||||
|
let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
|
||||
|
assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));
|
||||
|
|
||||
|
// check some other sibling for some other index
|
||||
|
let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
|
||||
|
assert_eq!(store.get_lone_sibling(index), None);
|
||||
|
}
|
||||
|
}
|
@ -0,0 +1,31 @@ |
|||||
|
/// A trait for computing the difference between two objects.
|
||||
|
pub trait Diff<K: Ord + Clone, V: Clone> {
|
||||
|
/// The type that describes the difference between two objects.
|
||||
|
type DiffType;
|
||||
|
|
||||
|
/// Returns a [Self::DiffType] object that represents the difference between this object and
|
||||
|
/// other.
|
||||
|
fn diff(&self, other: &Self) -> Self::DiffType;
|
||||
|
}
|
||||
|
|
||||
|
/// A trait for applying the difference between two objects.
|
||||
|
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
|
/// The type that describes the difference between two objects.
|
||||
|
type DiffType;
|
||||
|
|
||||
|
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
|
fn apply(&mut self, diff: Self::DiffType);
|
||||
|
}
|
||||
|
|
||||
|
/// A trait for applying the difference between two objects with the possibility of failure.
|
||||
|
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
|
/// The type that describes the difference between two objects.
|
||||
|
type DiffType;
|
||||
|
|
||||
|
/// An error type that can be returned if the changes cannot be applied.
|
||||
|
type Error;
|
||||
|
|
||||
|
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
|
/// Returns an error if the changes cannot be applied.
|
||||
|
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
|
||||
|
}
|