Tracking PR for v0.7.0 release
@ -0,0 +1,3 @@ |
|||
[submodule "PQClean"] |
|||
path = PQClean |
|||
url = https://github.com/PQClean/PQClean.git |
@ -0,0 +1,78 @@ |
|||
#include <stddef.h> |
|||
#include <arm_sve.h> |
|||
#include "library.h" |
|||
#include "rpo_hash.h" |
|||
|
|||
// The STATE_WIDTH of RPO hash is 12x u64 elements. |
|||
// The current generation of SVE-enabled processors - Neoverse V1 |
|||
// (e.g. AWS Graviton3) have 256-bit vector registers (4x u64) |
|||
// This allows us to split the state into 3 vectors of 4 elements |
|||
// and process all 3 independent of each other. |
|||
|
|||
// We see the biggest performance gains by leveraging both |
|||
// vector and scalar operations on parts of the state array. |
|||
// Due to high latency of vector operations, the processor is able |
|||
// to reorder and pipeline scalar instructions while we wait for |
|||
// vector results. This effectively gives us some 'free' scalar |
|||
// operations and masks vector latency. |
|||
// |
|||
// This also means that we can fully saturate all four arithmetic |
|||
// units of the processor (2x scalar, 2x SIMD) |
|||
// |
|||
// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS |
|||
// GAIN WIDER REGISTERS. It's quite possible that with 8x u64 |
|||
// vectors processing 2 partially filled vectors might |
|||
// be easier and faster than dealing with scalar operations |
|||
// on the remainder of the array. |
|||
// |
|||
// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back |
|||
// to the regular, already highly-optimized scalar version |
|||
// if the conditions are not met. |
|||
|
|||
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { |
|||
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector |
|||
|
|||
if (vl != 4) { |
|||
return false; |
|||
} |
|||
|
|||
svbool_t ptrue = svptrue_b64(); |
|||
|
|||
svuint64_t state1 = svld1(ptrue, state + 0*vl); |
|||
svuint64_t state2 = svld1(ptrue, state + 1*vl); |
|||
|
|||
svuint64_t const1 = svld1(ptrue, constants + 0*vl); |
|||
svuint64_t const2 = svld1(ptrue, constants + 1*vl); |
|||
|
|||
add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8); |
|||
apply_sbox(ptrue, &state1, &state2, state+8); |
|||
|
|||
svst1(ptrue, state + 0*vl, state1); |
|||
svst1(ptrue, state + 1*vl, state2); |
|||
|
|||
return true; |
|||
} |
|||
|
|||
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { |
|||
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector |
|||
|
|||
if (vl != 4) { |
|||
return false; |
|||
} |
|||
|
|||
svbool_t ptrue = svptrue_b64(); |
|||
|
|||
svuint64_t state1 = svld1(ptrue, state + 0 * vl); |
|||
svuint64_t state2 = svld1(ptrue, state + 1 * vl); |
|||
|
|||
svuint64_t const1 = svld1(ptrue, constants + 0 * vl); |
|||
svuint64_t const2 = svld1(ptrue, constants + 1 * vl); |
|||
|
|||
add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8); |
|||
apply_inv_sbox(ptrue, &state1, &state2, state + 8); |
|||
|
|||
svst1(ptrue, state + 0 * vl, state1); |
|||
svst1(ptrue, state + 1 * vl, state2); |
|||
|
|||
return true; |
|||
} |
@ -0,0 +1,12 @@ |
|||
#ifndef CRYPTO_LIBRARY_H |
|||
#define CRYPTO_LIBRARY_H |
|||
|
|||
#include <stdint.h> |
|||
#include <stdbool.h> |
|||
|
|||
#define STATE_WIDTH 12 |
|||
|
|||
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); |
|||
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); |
|||
|
|||
#endif //CRYPTO_LIBRARY_H |
@ -0,0 +1,221 @@ |
|||
#ifndef RPO_SVE_RPO_HASH_H |
|||
#define RPO_SVE_RPO_HASH_H |
|||
|
|||
#include <arm_sve.h> |
|||
#include <stddef.h> |
|||
#include <stdint.h> |
|||
#include <string.h> |
|||
|
|||
#define COPY(NAME, VIN1, VIN2, SIN3) \ |
|||
svuint64_t NAME ## _1 = VIN1; \ |
|||
svuint64_t NAME ## _2 = VIN2; \ |
|||
uint64_t NAME ## _3[4]; \ |
|||
memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t)) |
|||
|
|||
#define MULTIPLY(PRED, DEST, OP) \ |
|||
mul(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3) |
|||
|
|||
#define SQUARE(PRED, NAME) \ |
|||
sq(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3) |
|||
|
|||
#define SQUARE_DEST(PRED, DEST, SRC) \ |
|||
COPY(DEST, SRC ## _1, SRC ## _2, SRC ## _3); \ |
|||
SQUARE(PRED, DEST); |
|||
|
|||
#define POW_ACC(PRED, NAME, CNT, TAIL) \ |
|||
for (size_t i = 0; i < CNT; i++) { \ |
|||
SQUARE(PRED, NAME); \ |
|||
} \ |
|||
MULTIPLY(PRED, NAME, TAIL); |
|||
|
|||
#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \ |
|||
COPY(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \ |
|||
POW_ACC(PRED, DEST, CNT, TAIL) |
|||
|
|||
extern inline void add_constants( |
|||
svbool_t pg, |
|||
svuint64_t *state1, |
|||
svuint64_t *const1, |
|||
svuint64_t *state2, |
|||
svuint64_t *const2, |
|||
uint64_t *state3, |
|||
uint64_t *const3 |
|||
) { |
|||
uint64_t Ms = 0xFFFFFFFF00000001ull; |
|||
svuint64_t Mv = svindex_u64(Ms, 0); |
|||
|
|||
uint64_t p_1 = Ms - const3[0]; |
|||
uint64_t p_2 = Ms - const3[1]; |
|||
uint64_t p_3 = Ms - const3[2]; |
|||
uint64_t p_4 = Ms - const3[3]; |
|||
|
|||
uint64_t x_1, x_2, x_3, x_4; |
|||
uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1); |
|||
uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2); |
|||
uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3); |
|||
uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4); |
|||
|
|||
state3[0] = x_1 - (uint64_t)adj_1; |
|||
state3[1] = x_2 - (uint64_t)adj_2; |
|||
state3[2] = x_3 - (uint64_t)adj_3; |
|||
state3[3] = x_4 - (uint64_t)adj_4; |
|||
|
|||
svuint64_t p1 = svsub_x(pg, Mv, *const1); |
|||
svuint64_t p2 = svsub_x(pg, Mv, *const2); |
|||
|
|||
svuint64_t x1 = svsub_x(pg, *state1, p1); |
|||
svuint64_t x2 = svsub_x(pg, *state2, p2); |
|||
|
|||
svbool_t pt1 = svcmplt_u64(pg, *state1, p1); |
|||
svbool_t pt2 = svcmplt_u64(pg, *state2, p2); |
|||
|
|||
*state1 = svsub_m(pt1, x1, (uint32_t)-1); |
|||
*state2 = svsub_m(pt2, x2, (uint32_t)-1); |
|||
} |
|||
|
|||
extern inline void mul( |
|||
svbool_t pg, |
|||
svuint64_t *r1, |
|||
const svuint64_t *op1, |
|||
svuint64_t *r2, |
|||
const svuint64_t *op2, |
|||
uint64_t *r3, |
|||
const uint64_t *op3 |
|||
) { |
|||
__uint128_t x_1 = r3[0]; |
|||
__uint128_t x_2 = r3[1]; |
|||
__uint128_t x_3 = r3[2]; |
|||
__uint128_t x_4 = r3[3]; |
|||
|
|||
x_1 *= (__uint128_t) op3[0]; |
|||
x_2 *= (__uint128_t) op3[1]; |
|||
x_3 *= (__uint128_t) op3[2]; |
|||
x_4 *= (__uint128_t) op3[3]; |
|||
|
|||
uint64_t x0_1 = x_1; |
|||
uint64_t x0_2 = x_2; |
|||
uint64_t x0_3 = x_3; |
|||
uint64_t x0_4 = x_4; |
|||
|
|||
svuint64_t l1 = svmul_x(pg, *r1, *op1); |
|||
svuint64_t l2 = svmul_x(pg, *r2, *op2); |
|||
|
|||
uint64_t x1_1 = (x_1 >> 64); |
|||
uint64_t x1_2 = (x_2 >> 64); |
|||
uint64_t x1_3 = (x_3 >> 64); |
|||
uint64_t x1_4 = (x_4 >> 64); |
|||
|
|||
uint64_t a_1, a_2, a_3, a_4; |
|||
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1); |
|||
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2); |
|||
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3); |
|||
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4); |
|||
|
|||
svuint64_t ls1 = svlsl_x(pg, l1, 32); |
|||
svuint64_t ls2 = svlsl_x(pg, l2, 32); |
|||
|
|||
svuint64_t a1 = svadd_x(pg, l1, ls1); |
|||
svuint64_t a2 = svadd_x(pg, l2, ls2); |
|||
|
|||
svbool_t e1 = svcmplt(pg, a1, l1); |
|||
svbool_t e2 = svcmplt(pg, a2, l2); |
|||
|
|||
svuint64_t as1 = svlsr_x(pg, a1, 32); |
|||
svuint64_t as2 = svlsr_x(pg, a2, 32); |
|||
|
|||
svuint64_t b1 = svsub_x(pg, a1, as1); |
|||
svuint64_t b2 = svsub_x(pg, a2, as2); |
|||
|
|||
b1 = svsub_m(e1, b1, 1); |
|||
b2 = svsub_m(e2, b2, 1); |
|||
|
|||
uint64_t b_1 = a_1 - (a_1 >> 32) - e_1; |
|||
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2; |
|||
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3; |
|||
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4; |
|||
|
|||
uint64_t r_1, r_2, r_3, r_4; |
|||
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1); |
|||
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2); |
|||
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3); |
|||
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4); |
|||
|
|||
svuint64_t h1 = svmulh_x(pg, *r1, *op1); |
|||
svuint64_t h2 = svmulh_x(pg, *r2, *op2); |
|||
|
|||
svuint64_t tr1 = svsub_x(pg, h1, b1); |
|||
svuint64_t tr2 = svsub_x(pg, h2, b2); |
|||
|
|||
svbool_t c1 = svcmplt_u64(pg, h1, b1); |
|||
svbool_t c2 = svcmplt_u64(pg, h2, b2); |
|||
|
|||
*r1 = svsub_m(c1, tr1, (uint32_t) -1); |
|||
*r2 = svsub_m(c2, tr2, (uint32_t) -1); |
|||
|
|||
uint32_t minus1_1 = 0 - c_1; |
|||
uint32_t minus1_2 = 0 - c_2; |
|||
uint32_t minus1_3 = 0 - c_3; |
|||
uint32_t minus1_4 = 0 - c_4; |
|||
|
|||
r3[0] = r_1 - (uint64_t)minus1_1; |
|||
r3[1] = r_2 - (uint64_t)minus1_2; |
|||
r3[2] = r_3 - (uint64_t)minus1_3; |
|||
r3[3] = r_4 - (uint64_t)minus1_4; |
|||
} |
|||
|
|||
// Squares the full 12-element state in place: state <- state^2 (mod p).
// Thin wrapper over mul() with both operand sets aliased to the state.
extern inline void sq(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) {
    mul(pg, a, a, b, b, c, c);
}
|||
|
|||
extern inline void apply_sbox( |
|||
svbool_t pg, |
|||
svuint64_t *state1, |
|||
svuint64_t *state2, |
|||
uint64_t *state3 |
|||
) { |
|||
COPY(x, *state1, *state2, state3); // copy input to x |
|||
SQUARE(pg, x); // x contains input^2 |
|||
mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^3 |
|||
SQUARE(pg, x); // x contains input^4 |
|||
mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^7 |
|||
} |
|||
|
|||
extern inline void apply_inv_sbox( |
|||
svbool_t pg, |
|||
svuint64_t *state_1, |
|||
svuint64_t *state_2, |
|||
uint64_t *state_3 |
|||
) { |
|||
// base^10 |
|||
COPY(t1, *state_1, *state_2, state_3); |
|||
SQUARE(pg, t1); |
|||
|
|||
// base^100 |
|||
SQUARE_DEST(pg, t2, t1); |
|||
|
|||
// base^100100 |
|||
POW_ACC_DEST(pg, t3, 3, t2, t2); |
|||
|
|||
// base^100100100100 |
|||
POW_ACC_DEST(pg, t4, 6, t3, t3); |
|||
|
|||
// compute base^100100100100100100100100 |
|||
POW_ACC_DEST(pg, t5, 12, t4, t4); |
|||
|
|||
// compute base^100100100100100100100100100100 |
|||
POW_ACC_DEST(pg, t6, 6, t5, t3); |
|||
|
|||
// compute base^1001001001001001001001001001000100100100100100100100100100100 |
|||
POW_ACC_DEST(pg, t7, 31, t6, t6); |
|||
|
|||
// compute base^1001001001001001001001001001000110110110110110110110110110110111 |
|||
SQUARE(pg, t7); |
|||
MULTIPLY(pg, t7, t6); |
|||
SQUARE(pg, t7); |
|||
SQUARE(pg, t7); |
|||
MULTIPLY(pg, t7, t1); |
|||
MULTIPLY(pg, t7, t2); |
|||
mul(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3); |
|||
} |
|||
|
|||
#endif //RPO_SVE_RPO_HASH_H |
@ -0,0 +1,50 @@ |
|||
fn main() {
    // The Falcon C wrapper requires std; the SVE kernel additionally
    // requires both the compile-time target feature and the crate feature.
    #[cfg(feature = "std")]
    compile_rpo_falcon();

    #[cfg(all(target_feature = "sve", feature = "sve"))]
    compile_arch_arm64_sve();
}
|
|||
|
|||
/// Compiles the Falcon-512 C sources (PQClean reference code plus the
/// local RPO wrapper) into the `rpo_falcon512` static library.
#[cfg(feature = "std")]
fn compile_rpo_falcon() {
    use std::path::PathBuf;

    const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";

    // Rebuild whenever any of the wrapper sources change.
    for file in ["falcon.h", "falcon.c", "rpo.h", "rpo.c"] {
        println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/{file}");
    }

    // PQClean reference implementation (git submodule).
    let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
    let common_dir: PathBuf = ["PQClean", "common"].iter().collect();

    let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
    let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();

    let mut build = cc::Build::new();
    build.include(&common_dir).include(target_dir);
    for source in scheme_files.chain(common_files) {
        build.file(source.unwrap().to_string_lossy().into_owned());
    }
    build
        .file(format!("{RPO_FALCON_PATH}/falcon.c"))
        .file(format!("{RPO_FALCON_PATH}/rpo.c"))
        .flag("-O3")
        .compile("rpo_falcon512");
}
|
|||
|
|||
/// Compiles the ARM64 SVE implementation of the RPO hash into the
/// `rpo_sve` static library.
#[cfg(all(target_feature = "sve", feature = "sve"))]
fn compile_arch_arm64_sve() {
    const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

    // Rebuild whenever any of the SVE kernel sources change.
    for file in ["library.c", "library.h", "rpo_hash.h"] {
        println!("cargo:rerun-if-changed={RPO_SVE_PATH}/{file}");
    }

    cc::Build::new()
        .file(format!("{RPO_SVE_PATH}/library.c"))
        // Enable the SVE instruction set for this translation unit.
        .flag("-march=armv8-a+sve")
        .flag("-O3")
        .compile("rpo_sve");
}
|
@ -0,0 +1 @@ |
|||
pub mod rpo_falcon512;
|
@ -0,0 +1,55 @@ |
|||
use super::{LOG_N, MODULUS, PK_LEN};
|
|||
use core::fmt;
|
|||
|
|||
/// Errors that can arise while generating Falcon-512 key pairs and
/// signatures, or while decoding their byte representations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FalconError {
    /// Generating a private-public key pair failed.
    KeyGenerationFailed,
    /// Public key decoding did not consume the whole input.
    PubKeyDecodingExtraData,
    /// A decoded public key coefficient was not below the field modulus.
    PubKeyDecodingInvalidCoefficient(u32),
    /// The public key byte string had an unexpected length.
    PubKeyDecodingInvalidLength(usize),
    /// The public key header byte had an unexpected value.
    PubKeyDecodingInvalidTag(u8),
    /// The signature's high bits exceeded the supported bound.
    SigDecodingTooBigHighBits(u32),
    /// The signature's trailing data could not be decoded.
    SigDecodingInvalidRemainder,
    /// Unused bits in the signature's last byte were non-zero.
    SigDecodingNonZeroUnusedBitsLastByte,
    /// The forbidden "-0" value was encountered in a signature.
    SigDecodingMinusZero,
    /// The signature used an unsupported encoding algorithm.
    SigDecodingIncorrectEncodingAlgorithm,
    /// The signature's polynomial degree is not the supported 512.
    SigDecodingNotSupportedDegree(u8),
    /// Generating a signature failed.
    SigGenerationFailed,
}
|
|||
|
|||
impl fmt::Display for FalconError {
|
|||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|||
use FalconError::*;
|
|||
match self {
|
|||
KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
|
|||
PubKeyDecodingExtraData => {
|
|||
write!(f, "Failed to decode public key: input not fully consumed")
|
|||
}
|
|||
PubKeyDecodingInvalidCoefficient(val) => {
|
|||
write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
|
|||
}
|
|||
PubKeyDecodingInvalidLength(len) => {
|
|||
write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
|
|||
}
|
|||
PubKeyDecodingInvalidTag(byte) => {
|
|||
write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
|
|||
}
|
|||
SigDecodingTooBigHighBits(m) => {
|
|||
write!(f, "Failed to decode signature: high bits {m} exceed 2048")
|
|||
}
|
|||
SigDecodingInvalidRemainder => {
|
|||
write!(f, "Failed to decode signature: incorrect remaining data")
|
|||
}
|
|||
SigDecodingNonZeroUnusedBitsLastByte => {
|
|||
write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
|
|||
}
|
|||
SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
|
|||
SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
|
|||
SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
|
|||
SigGenerationFailed => write!(f, "Failed to generate a signature"),
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
// `std::error::Error` is only available when the standard library is
// enabled; the blanket impl relies on the `Display`/`Debug` impls above.
#[cfg(feature = "std")]
impl std::error::Error for FalconError {}
|
@ -0,0 +1,402 @@ |
|||
/* |
|||
* Wrapper for implementing the PQClean API. |
|||
*/ |
|||
|
|||
#include <string.h> |
|||
#include "randombytes.h" |
|||
#include "falcon.h" |
|||
#include "inner.h" |
|||
#include "rpo.h" |
|||
|
|||
#define NONCELEN 40 |
|||
|
|||
/* |
|||
* Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024) |
|||
* |
|||
* private key: |
|||
* header byte: 0101nnnn |
|||
* private f (6 or 5 bits by element, depending on degree) |
|||
* private g (6 or 5 bits by element, depending on degree) |
|||
* private F (8 bits by element) |
|||
* |
|||
* public key: |
|||
* header byte: 0000nnnn |
|||
* public h (14 bits by element) |
|||
* |
|||
* signature: |
|||
* header byte: 0011nnnn |
|||
* nonce 40 bytes |
|||
* value (12 bits by element) |
|||
* |
|||
* message + signature: |
|||
* signature length (2 bytes, big-endian) |
|||
* nonce 40 bytes |
|||
* message |
|||
* header byte: 0010nnnn |
|||
* value (12 bits by element) |
|||
* (signature length is 1+len(value), not counting the nonce) |
|||
*/ |
|||
|
|||
/* see falcon.h */ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( |
|||
uint8_t *pk, |
|||
uint8_t *sk, |
|||
unsigned char *seed |
|||
) { |
|||
union |
|||
{ |
|||
uint8_t b[FALCON_KEYGEN_TEMP_9]; |
|||
uint64_t dummy_u64; |
|||
fpr dummy_fpr; |
|||
} tmp; |
|||
int8_t f[512], g[512], F[512]; |
|||
uint16_t h[512]; |
|||
inner_shake256_context rng; |
|||
size_t u, v; |
|||
|
|||
/* |
|||
* Generate key pair. |
|||
*/ |
|||
inner_shake256_init(&rng); |
|||
inner_shake256_inject(&rng, seed, sizeof seed); |
|||
inner_shake256_flip(&rng); |
|||
PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b); |
|||
inner_shake256_ctx_release(&rng); |
|||
|
|||
/* |
|||
* Encode private key. |
|||
*/ |
|||
sk[0] = 0x50 + 9; |
|||
u = 1; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
|||
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
|||
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, |
|||
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) |
|||
{ |
|||
return -1; |
|||
} |
|||
|
|||
/* |
|||
* Encode public key. |
|||
*/ |
|||
pk[0] = 0x00 + 9; |
|||
v = PQCLEAN_FALCON512_CLEAN_modq_encode( |
|||
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1, |
|||
h, 9); |
|||
if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
|||
{ |
|||
return -1; |
|||
} |
|||
|
|||
return 0; |
|||
} |
|||
|
|||
/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk,
    uint8_t *sk
) {
    /* Draw a fresh 48-byte seed and derive the key pair from it. */
    unsigned char seed[48];

    randombytes(seed, sizeof seed);

    return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed);
}
|||
|
|||
/* |
|||
* Compute the signature. nonce[] receives the nonce and must have length |
|||
* NONCELEN bytes. sigbuf[] receives the signature value (without nonce |
|||
* or header byte), with *sigbuflen providing the maximum value length and |
|||
* receiving the actual value length. |
|||
* |
|||
* If a signature could be computed but not encoded because it would |
|||
* exceed the output buffer size, then a new signature is computed. If |
|||
* the provided buffer size is too low, this could loop indefinitely, so |
|||
* the caller must provide a size that can accommodate signatures with a |
|||
* large enough probability. |
|||
* |
|||
* Return value: 0 on success, -1 on error. |
|||
*/ |
|||
static int do_sign( |
|||
uint8_t *nonce, |
|||
uint8_t *sigbuf, |
|||
size_t *sigbuflen, |
|||
const uint8_t *m, |
|||
size_t mlen, |
|||
const uint8_t *sk |
|||
) { |
|||
union |
|||
{ |
|||
uint8_t b[72 * 512]; |
|||
uint64_t dummy_u64; |
|||
fpr dummy_fpr; |
|||
} tmp; |
|||
int8_t f[512], g[512], F[512], G[512]; |
|||
struct |
|||
{ |
|||
int16_t sig[512]; |
|||
uint16_t hm[512]; |
|||
} r; |
|||
unsigned char seed[48]; |
|||
inner_shake256_context sc; |
|||
rpo128_context rc; |
|||
size_t u, v; |
|||
|
|||
/* |
|||
* Decode the private key. |
|||
*/ |
|||
if (sk[0] != 0x50 + 9) |
|||
{ |
|||
return -1; |
|||
} |
|||
u = 1; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
|||
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
|||
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( |
|||
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9], |
|||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); |
|||
if (v == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
u += v; |
|||
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) |
|||
{ |
|||
return -1; |
|||
} |
|||
if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b)) |
|||
{ |
|||
return -1; |
|||
} |
|||
|
|||
/* |
|||
* Create a random nonce (40 bytes). |
|||
*/ |
|||
randombytes(nonce, NONCELEN); |
|||
|
|||
/* ==== Start: Deviation from the reference implementation ================================= */ |
|||
|
|||
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that |
|||
// the conversion to field elements succeeds |
|||
uint8_t buffer[64]; |
|||
memset(buffer, 0, 64); |
|||
for (size_t i = 0; i < 8; i++) |
|||
{ |
|||
buffer[8 * i] = nonce[5 * i]; |
|||
buffer[8 * i + 1] = nonce[5 * i + 1]; |
|||
buffer[8 * i + 2] = nonce[5 * i + 2]; |
|||
buffer[8 * i + 3] = nonce[5 * i + 3]; |
|||
buffer[8 * i + 4] = nonce[5 * i + 4]; |
|||
} |
|||
|
|||
/* |
|||
* Hash message nonce + message into a vector. |
|||
*/ |
|||
rpo128_init(&rc); |
|||
rpo128_absorb(&rc, buffer, NONCELEN + 24); |
|||
rpo128_absorb(&rc, m, mlen); |
|||
rpo128_finalize(&rc); |
|||
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9); |
|||
rpo128_release(&rc); |
|||
|
|||
/* ==== End: Deviation from the reference implementation =================================== */ |
|||
|
|||
/* |
|||
* Initialize a RNG. |
|||
*/ |
|||
randombytes(seed, sizeof seed); |
|||
inner_shake256_init(&sc); |
|||
inner_shake256_inject(&sc, seed, sizeof seed); |
|||
inner_shake256_flip(&sc); |
|||
|
|||
/* |
|||
* Compute and return the signature. This loops until a signature |
|||
* value is found that fits in the provided buffer. |
|||
*/ |
|||
for (;;) |
|||
{ |
|||
PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b); |
|||
v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9); |
|||
if (v != 0) |
|||
{ |
|||
inner_shake256_ctx_release(&sc); |
|||
*sigbuflen = v; |
|||
return 0; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/* |
|||
* Verify a signature. The nonce has size NONCELEN bytes. sigbuf[] |
|||
* (of size sigbuflen) contains the signature value, not including the |
|||
* header byte or nonce. Return value is 0 on success, -1 on error. |
|||
*/ |
|||
static int do_verify( |
|||
const uint8_t *nonce, |
|||
const uint8_t *sigbuf, |
|||
size_t sigbuflen, |
|||
const uint8_t *m, |
|||
size_t mlen, |
|||
const uint8_t *pk |
|||
) { |
|||
union |
|||
{ |
|||
uint8_t b[2 * 512]; |
|||
uint64_t dummy_u64; |
|||
fpr dummy_fpr; |
|||
} tmp; |
|||
uint16_t h[512], hm[512]; |
|||
int16_t sig[512]; |
|||
rpo128_context rc; |
|||
|
|||
/* |
|||
* Decode public key. |
|||
*/ |
|||
if (pk[0] != 0x00 + 9) |
|||
{ |
|||
return -1; |
|||
} |
|||
if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9, |
|||
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
|||
!= PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) |
|||
{ |
|||
return -1; |
|||
} |
|||
PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9); |
|||
|
|||
/* |
|||
* Decode signature. |
|||
*/ |
|||
if (sigbuflen == 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen) |
|||
{ |
|||
return -1; |
|||
} |
|||
|
|||
/* ==== Start: Deviation from the reference implementation ================================= */ |
|||
|
|||
/* |
|||
* Hash nonce + message into a vector. |
|||
*/ |
|||
|
|||
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that |
|||
// the conversion to field elements succeeds |
|||
uint8_t buffer[64]; |
|||
memset(buffer, 0, 64); |
|||
for (size_t i = 0; i < 8; i++) |
|||
{ |
|||
buffer[8 * i] = nonce[5 * i]; |
|||
buffer[8 * i + 1] = nonce[5 * i + 1]; |
|||
buffer[8 * i + 2] = nonce[5 * i + 2]; |
|||
buffer[8 * i + 3] = nonce[5 * i + 3]; |
|||
buffer[8 * i + 4] = nonce[5 * i + 4]; |
|||
} |
|||
|
|||
rpo128_init(&rc); |
|||
rpo128_absorb(&rc, buffer, NONCELEN + 24); |
|||
rpo128_absorb(&rc, m, mlen); |
|||
rpo128_finalize(&rc); |
|||
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9); |
|||
rpo128_release(&rc); |
|||
|
|||
/* === End: Deviation from the reference implementation ==================================== */ |
|||
|
|||
/* |
|||
* Verify signature. |
|||
*/ |
|||
if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b)) |
|||
{ |
|||
return -1; |
|||
} |
|||
return 0; |
|||
} |
|||
|
|||
/* see falcon.h */ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( |
|||
uint8_t *sig, |
|||
size_t *siglen, |
|||
const uint8_t *m, |
|||
size_t mlen, |
|||
const uint8_t *sk |
|||
) { |
|||
/* |
|||
* The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for |
|||
* the signed message object (as produced by crypto_sign()) |
|||
* and includes a two-byte length value, so we take care here |
|||
* to only generate signatures that are two bytes shorter than |
|||
* the maximum. This is done to ensure that crypto_sign() |
|||
* and crypto_sign_signature() produce the exact same signature |
|||
* value, if used on the same message, with the same private key, |
|||
* and using the same output from randombytes() (this is for |
|||
* reproducibility of tests). |
|||
*/ |
|||
size_t vlen; |
|||
|
|||
vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3; |
|||
if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0) |
|||
{ |
|||
return -1; |
|||
} |
|||
sig[0] = 0x30 + 9; |
|||
*siglen = 1 + NONCELEN + vlen; |
|||
return 0; |
|||
} |
|||
|
|||
/* see falcon.h */ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( |
|||
const uint8_t *sig, |
|||
size_t siglen, |
|||
const uint8_t *m, |
|||
size_t mlen, |
|||
const uint8_t *pk |
|||
) { |
|||
if (siglen < 1 + NONCELEN) |
|||
{ |
|||
return -1; |
|||
} |
|||
if (sig[0] != 0x30 + 9) |
|||
{ |
|||
return -1; |
|||
} |
|||
return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk); |
|||
} |
@ -0,0 +1,66 @@ |
|||
#include <stddef.h> |
|||
#include <stdint.h> |
|||
|
|||
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281 |
|||
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897 |
|||
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666 |
|||
|
|||
/* |
|||
* Generate a new key pair. Public key goes into pk[], private key in sk[]. |
|||
* Key sizes are exact (in bytes): |
|||
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES |
|||
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES |
|||
* |
|||
* Return value: 0 on success, -1 on error. |
|||
* |
|||
* Note: This implementation follows the reference implementation in PQClean |
|||
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
|||
* verbatim except for the sections that are marked otherwise. |
|||
*/ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( |
|||
uint8_t *pk, uint8_t *sk); |
|||
|
|||
/* |
|||
* Generate a new key pair from seed. Public key goes into pk[], private key in sk[]. |
|||
* Key sizes are exact (in bytes): |
|||
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES |
|||
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES |
|||
* |
|||
* Return value: 0 on success, -1 on error. |
|||
*/ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( |
|||
uint8_t *pk, uint8_t *sk, unsigned char *seed); |
|||
|
|||
/* |
|||
* Compute a signature on a provided message (m, mlen), with a given |
|||
* private key (sk). Signature is written in sig[], with length written |
|||
* into *siglen. Signature length is variable; maximum signature length |
|||
* (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES. |
|||
* |
|||
* sig[], m[] and sk[] may overlap each other arbitrarily. |
|||
* |
|||
* Return value: 0 on success, -1 on error. |
|||
* |
|||
* Note: This implementation follows the reference implementation in PQClean |
|||
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
|||
* verbatim except for the sections that are marked otherwise. |
|||
*/ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( |
|||
uint8_t *sig, size_t *siglen, |
|||
const uint8_t *m, size_t mlen, const uint8_t *sk); |
|||
|
|||
/* |
|||
* Verify a signature (sig, siglen) on a message (m, mlen) with a given |
|||
* public key (pk). |
|||
* |
|||
* sig[], m[] and pk[] may overlap each other arbitrarily. |
|||
* |
|||
* Return value: 0 on success, -1 on error. |
|||
* |
|||
* Note: This implementation follows the reference implementation in PQClean |
|||
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 |
|||
* verbatim except for the sections that are marked otherwise. |
|||
*/ |
|||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( |
|||
const uint8_t *sig, size_t siglen, |
|||
const uint8_t *m, size_t mlen, const uint8_t *pk); |
@ -0,0 +1,582 @@ |
|||
/* |
|||
* RPO implementation. |
|||
*/ |
|||
|
|||
#include <stdint.h> |
|||
#include <string.h> |
|||
#include <stdlib.h> |
|||
|
|||
/* ================================================================================================ |
|||
* Modular Arithmetic |
|||
*/ |
|||
|
|||
/* Goldilocks prime p = 2^64 - 2^32 + 1 and the Falcon modulus. */
#define P 0xFFFFFFFF00000001
#define M 12289

// Branchy scalar reference implementations, adapted from
// https://github.com/ncw/iprime/blob/master/mod_math_noasm.go

/* (a + b) mod P. Computed as b - (P - a), adding P back on underflow. */
static uint64_t add_mod_p(uint64_t a, uint64_t b)
{
    a = P - a;
    uint64_t res = b - a;
    if (b < a)
        res += P;
    return res;
}

/* (a - b) mod P, adding P back on underflow. */
static uint64_t sub_mod_p(uint64_t a, uint64_t b)
{
    uint64_t r = a - b;
    if (a < b)
        r += P;
    return r;
}

/* Reduces the 128-bit value b*2^64 + a modulo P.
 * Writing b = d*2^32 + c, it uses 2^64 === 2^32 - 1 (mod P), i.e.
 * b*2^64 === c*2^32 + c - d - c = (c << 32) - c - d (mod P). */
static uint64_t reduce_mod_p(uint64_t b, uint64_t a)
{
    uint32_t d = b >> 32,
             c = b;
    if (a >= P)
        a -= P;
    a = sub_mod_p(a, c);
    a = sub_mod_p(a, d);
    a = add_mod_p(a, ((uint64_t)c) << 32);
    return a;
}

/* (x * y) mod P, synthesizing the 128-bit product from four
 * 32x32 -> 64-bit partial products, then reducing. */
static uint64_t mult_mod_p(uint64_t x, uint64_t y)
{
    uint32_t a = x,
             b = x >> 32,
             c = y,
             d = y >> 32;

    /* Partial products. */
    x = b * (uint64_t)c;             /* b*c */
    y = a * (uint64_t)d;             /* a*d */
    uint64_t e = a * (uint64_t)c,    /* a*c */
             f = b * (uint64_t)d,    /* b*d */
             t;

    x += y;                          /* b*c + a*d */
    if (x < y)                       /* carry? */
        f += 1LL << 32;              /* carry into upper 32 bits - can't overflow */

    t = x << 32;
    e += t;                          /* a*c + LSW(b*c + a*d) */
    if (e < t)                       /* carry? */
        f += 1;                      /* carry into upper 64 bits - can't overflow */
    t = x >> 32;
    f += t;                          /* b*d + MSW(b*c + a*d); can't overflow */

    /* Reduce (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) mod P. */
    return reduce_mod_p(f, e);
}
|||
|
|||
/* ================================================================================================ |
|||
* RPO128 Permutation |
|||
*/ |
|||
|
|||
static const uint64_t STATE_WIDTH = 12; |
|||
static const uint64_t NUM_ROUNDS = 7; |
|||
|
|||
/* |
|||
* MDS matrix |
|||
*/ |
|||
static const uint64_t MDS[12][12] = { |
|||
{ 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 }, |
|||
{ 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 }, |
|||
{ 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 }, |
|||
{ 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6 }, |
|||
{ 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7 }, |
|||
{ 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9 }, |
|||
{ 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10 }, |
|||
{ 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13 }, |
|||
{ 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26 }, |
|||
{ 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8 }, |
|||
{ 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23 }, |
|||
{ 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7 }, |
|||
}; |
|||
|
|||
/*
 * Round constants.
 *
 * ARK1[r] holds the 12 constants added in the first half of round r
 * (see apply_round: MDS -> ARK1 -> S-box).
 */
static const uint64_t ARK1[7][12] = {
    {
        5789762306288267392ULL,
        6522564764413701783ULL,
        17809893479458208203ULL,
        107145243989736508ULL,
        6388978042437517382ULL,
        15844067734406016715ULL,
        9975000513555218239ULL,
        3344984123768313364ULL,
        9959189626657347191ULL,
        12960773468763563665ULL,
        9602914297752488475ULL,
        16657542370200465908ULL,
    },
    {
        12987190162843096997ULL,
        653957632802705281ULL,
        4441654670647621225ULL,
        4038207883745915761ULL,
        5613464648874830118ULL,
        13222989726778338773ULL,
        3037761201230264149ULL,
        16683759727265180203ULL,
        8337364536491240715ULL,
        3227397518293416448ULL,
        8110510111539674682ULL,
        2872078294163232137ULL,
    },
    {
        18072785500942327487ULL,
        6200974112677013481ULL,
        17682092219085884187ULL,
        10599526828986756440ULL,
        975003873302957338ULL,
        8264241093196931281ULL,
        10065763900435475170ULL,
        2181131744534710197ULL,
        6317303992309418647ULL,
        1401440938888741532ULL,
        8884468225181997494ULL,
        13066900325715521532ULL,
    },
    {
        5674685213610121970ULL,
        5759084860419474071ULL,
        13943282657648897737ULL,
        1352748651966375394ULL,
        17110913224029905221ULL,
        1003883795902368422ULL,
        4141870621881018291ULL,
        8121410972417424656ULL,
        14300518605864919529ULL,
        13712227150607670181ULL,
        17021852944633065291ULL,
        6252096473787587650ULL,
    },
    {
        4887609836208846458ULL,
        3027115137917284492ULL,
        9595098600469470675ULL,
        10528569829048484079ULL,
        7864689113198939815ULL,
        17533723827845969040ULL,
        5781638039037710951ULL,
        17024078752430719006ULL,
        109659393484013511ULL,
        7158933660534805869ULL,
        2955076958026921730ULL,
        7433723648458773977ULL,
    },
    {
        16308865189192447297ULL,
        11977192855656444890ULL,
        12532242556065780287ULL,
        14594890931430968898ULL,
        7291784239689209784ULL,
        5514718540551361949ULL,
        10025733853830934803ULL,
        7293794580341021693ULL,
        6728552937464861756ULL,
        6332385040983343262ULL,
        13277683694236792804ULL,
        2600778905124452676ULL,
    },
    {
        7123075680859040534ULL,
        1034205548717903090ULL,
        7717824418247931797ULL,
        3019070937878604058ULL,
        11403792746066867460ULL,
        10280580802233112374ULL,
        337153209462421218ULL,
        13333398568519923717ULL,
        3596153696935337464ULL,
        8104208463525993784ULL,
        14345062289456085693ULL,
        17036731477169661256ULL,
    }};
|||
|
|||
/*
 * Round constants added in the second half of each round
 * (see apply_round: MDS -> ARK2 -> inverse S-box).
 *
 * Declared `static` for internal linkage, consistent with ARK1 and MDS —
 * the bare `const` previously gave this table external linkage.
 */
static const uint64_t ARK2[7][12] = {
    {
        6077062762357204287ULL,
        15277620170502011191ULL,
        5358738125714196705ULL,
        14233283787297595718ULL,
        13792579614346651365ULL,
        11614812331536767105ULL,
        14871063686742261166ULL,
        10148237148793043499ULL,
        4457428952329675767ULL,
        15590786458219172475ULL,
        10063319113072092615ULL,
        14200078843431360086ULL,
    },
    {
        6202948458916099932ULL,
        17690140365333231091ULL,
        3595001575307484651ULL,
        373995945117666487ULL,
        1235734395091296013ULL,
        14172757457833931602ULL,
        707573103686350224ULL,
        15453217512188187135ULL,
        219777875004506018ULL,
        17876696346199469008ULL,
        17731621626449383378ULL,
        2897136237748376248ULL,
    },
    {
        8023374565629191455ULL,
        15013690343205953430ULL,
        4485500052507912973ULL,
        12489737547229155153ULL,
        9500452585969030576ULL,
        2054001340201038870ULL,
        12420704059284934186ULL,
        355990932618543755ULL,
        9071225051243523860ULL,
        12766199826003448536ULL,
        9045979173463556963ULL,
        12934431667190679898ULL,
    },
    {
        18389244934624494276ULL,
        16731736864863925227ULL,
        4440209734760478192ULL,
        17208448209698888938ULL,
        8739495587021565984ULL,
        17000774922218161967ULL,
        13533282547195532087ULL,
        525402848358706231ULL,
        16987541523062161972ULL,
        5466806524462797102ULL,
        14512769585918244983ULL,
        10973956031244051118ULL,
    },
    {
        6982293561042362913ULL,
        14065426295947720331ULL,
        16451845770444974180ULL,
        7139138592091306727ULL,
        9012006439959783127ULL,
        14619614108529063361ULL,
        1394813199588124371ULL,
        4635111139507788575ULL,
        16217473952264203365ULL,
        10782018226466330683ULL,
        6844229992533662050ULL,
        7446486531695178711ULL,
    },
    {
        3736792340494631448ULL,
        577852220195055341ULL,
        6689998335515779805ULL,
        13886063479078013492ULL,
        14358505101923202168ULL,
        7744142531772274164ULL,
        16135070735728404443ULL,
        12290902521256031137ULL,
        12059913662657709804ULL,
        16456018495793751911ULL,
        4571485474751953524ULL,
        17200392109565783176ULL,
    },
    {
        17130398059294018733ULL,
        519782857322261988ULL,
        9625384390925085478ULL,
        1664893052631119222ULL,
        7629576092524553570ULL,
        3485239601103661425ULL,
        9755891797164033838ULL,
        15218148195153269027ULL,
        16460604813734957368ULL,
        9643968136937729763ULL,
        3611348709641382851ULL,
        18256379591337759196ULL,
    },
};
|||
|
|||
static void apply_sbox(uint64_t *const state) |
|||
{ |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
uint64_t t2 = mult_mod_p(*(state + i), *(state + i)); |
|||
uint64_t t4 = mult_mod_p(t2, t2); |
|||
|
|||
*(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4)); |
|||
} |
|||
} |
|||
|
|||
static void apply_mds(uint64_t *state) |
|||
{ |
|||
uint64_t res[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
res[i] = 0; |
|||
} |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
for (uint64_t j = 0; j < STATE_WIDTH; j++) |
|||
{ |
|||
res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j))); |
|||
} |
|||
} |
|||
|
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
*(state + i) = res[i]; |
|||
} |
|||
} |
|||
|
|||
static void apply_constants(uint64_t *const state, const uint64_t *ark) |
|||
{ |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
*(state + i) = add_mod_p(*(state + i), *(ark + i)); |
|||
} |
|||
} |
|||
|
|||
static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res) |
|||
{ |
|||
for (uint64_t i = 0; i < m; i++) |
|||
{ |
|||
for (uint64_t j = 0; j < STATE_WIDTH; j++) |
|||
{ |
|||
if (i == 0) |
|||
{ |
|||
*(res + j) = mult_mod_p(*(base + j), *(base + j)); |
|||
} |
|||
else |
|||
{ |
|||
*(res + j) = mult_mod_p(*(res + j), *(res + j)); |
|||
} |
|||
} |
|||
} |
|||
|
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
*(res + i) = mult_mod_p(*(res + i), *(tail + i)); |
|||
} |
|||
} |
|||
|
|||
static void apply_inv_sbox(uint64_t *const state) |
|||
{ |
|||
uint64_t t1[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t1[i] = 0; |
|||
} |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t1[i] = mult_mod_p(*(state + i), *(state + i)); |
|||
} |
|||
|
|||
uint64_t t2[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t2[i] = 0; |
|||
} |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t2[i] = mult_mod_p(t1[i], t1[i]); |
|||
} |
|||
|
|||
uint64_t t3[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t3[i] = 0; |
|||
} |
|||
exp_acc(3, t2, t2, t3); |
|||
|
|||
uint64_t t4[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t4[i] = 0; |
|||
} |
|||
exp_acc(6, t3, t3, t4); |
|||
|
|||
uint64_t tmp[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
tmp[i] = 0; |
|||
} |
|||
exp_acc(12, t4, t4, tmp); |
|||
|
|||
uint64_t t5[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t5[i] = 0; |
|||
} |
|||
exp_acc(6, tmp, t3, t5); |
|||
|
|||
uint64_t t6[STATE_WIDTH]; |
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
t6[i] = 0; |
|||
} |
|||
exp_acc(31, t5, t5, t6); |
|||
|
|||
for (uint64_t i = 0; i < STATE_WIDTH; i++) |
|||
{ |
|||
uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]); |
|||
a = mult_mod_p(a, a); |
|||
a = mult_mod_p(a, a); |
|||
uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i)); |
|||
|
|||
*(state + i) = mult_mod_p(a, b); |
|||
} |
|||
} |
|||
|
|||
/*
 * Applies one full RPO round to the state: an MDS + ARK1 + S-box
 * half-round followed by an MDS + ARK2 + inverse-S-box half-round.
 */
static void apply_round(uint64_t *const state, const uint64_t round)
{
    apply_mds(state);
    apply_constants(state, ARK1[round]);
    apply_sbox(state);

    apply_mds(state);
    apply_constants(state, ARK2[round]);
    apply_inv_sbox(state);
}
|||
|
|||
static void apply_permutation(uint64_t *state) |
|||
{ |
|||
for (uint64_t i = 0; i < NUM_ROUNDS; i++) |
|||
{ |
|||
apply_round(state, i); |
|||
} |
|||
} |
|||
|
|||
/* ================================================================================================ |
|||
* RPO128 implementation. This is supposed to substitute SHAKE256 in the hash-to-point algorithm. |
|||
*/ |
|||
|
|||
#include "rpo.h" |
|||
|
|||
/*
 * Initializes an RPO128 sponge: zeroes the whole 96-byte state and sets the
 * data pointer just past the first 32 bytes. For RPO the capacity section
 * sits at the beginning of the buffer (unlike SHAKE256, where it is at the
 * end), so absorption/squeezing operates on bytes [32, 96).
 */
void rpo128_init(rpo128_context *rc)
{
    rc->dptr = 32;

    memset(rc->st.A, 0, sizeof rc->st.A);
}
|||
|
|||
/*
 * Absorbs 'len' bytes from 'in' into the sponge's rate section.
 *
 * 136 * 8 = 1088 bit for the rate portion in the case of SHAKE256;
 * for RPO, this is 64 * 8 = 512 bits. The capacity for SHAKE256 is at the
 * end while for RPO128 it is at the beginning, so the rate occupies
 * bytes [32, 96) of the buffer. Whenever the rate fills up, the
 * permutation is applied and absorption restarts at offset 32.
 */
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen;

        /* number of free rate bytes before the next permutation */
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }

        /* bulk copy instead of the previous byte-by-byte loop, matching
         * the memcpy already used in rpo128_squeeze() */
        memcpy(rc->st.dbuf + dptr, in, clen);

        dptr += clen;
        in += clen;
        len -= clen;
        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
    }
    rc->dptr = dptr;
}
|||
|
|||
/*
 * Finalizes the absorption phase. No padding is applied (see the note in
 * rpo.h about the fixed absorb sequence used by this codebase).
 */
void rpo128_finalize(rpo128_context *rc)
{
    // Set dptr to the end of the buffer, so that first call to extract will call the permutation.
    rc->dptr = 96;
}
|||
|
|||
/*
 * Squeezes 'len' output bytes from the sponge into 'out'.
 *
 * When the read offset reaches the end of the 96-byte buffer, the
 * permutation is applied and reading resumes at offset 32 (just past the
 * capacity section at the start of the buffer).
 */
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen;

        /* rate exhausted: permute and rewind past the capacity */
        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }
        len -= clen;

        memcpy(out, rc->st.dbuf + dptr, clen);
        dptr += clen;
        out += clen;
    }
    rc->dptr = dptr;
}
|||
|
|||
/*
 * Releases the state: wipes the whole buffer (so no absorbed/squeezed data
 * lingers in memory) and resets the data pointer as in rpo128_init().
 */
void rpo128_release(rpo128_context *rc)
{
    memset(rc->st.A, 0, sizeof rc->st.A);
    rc->dptr = 32;
}
|||
|
|||
/* ================================================================================================ |
|||
* Hash-to-Point algorithm implementation based on RPO128 |
|||
*/ |
|||
|
|||
/*
 * Generates 2^logn coefficients in Z_q from the finalized RPO128 sponge:
 * for each coefficient, squeeze 8 bytes, interpret them as a little-endian
 * u64, and reduce modulo M.
 */
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn)
{
    /*
     * This implementation avoids the rejection sampling step needed in the
     * per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf
     * page 31, which argues that the current variant is secure for the parameters set by NIST.
     * Avoiding the rejection-sampling step leads to an implementation that is constant-time.
     * TODO: Check that the current implementation is indeed constant-time.
     */
    size_t n = (size_t)1 << logn;

    while (n > 0)
    {
        uint8_t buf[8];
        uint64_t w = 0;

        rpo128_squeeze(rc, (void *)buf, sizeof buf);

        /* assemble the 8 squeezed bytes as a little-endian u64 */
        for (int i = 7; i >= 0; i--)
        {
            w = (w << 8) | (uint64_t)buf[i];
        }

        w %= M;

        *x++ = (uint16_t)w;
        n--;
    }
}
@ -0,0 +1,83 @@ |
|||
#include <stdint.h> |
|||
#include <string.h> |
|||
|
|||
/* ================================================================================================ |
|||
* RPO hashing algorithm related structs and methods. |
|||
*/ |
|||
|
|||
/* |
|||
* RPO128 context. |
|||
* |
|||
* This structure is used by the hashing API. It is composed of an internal state that can be |
|||
* viewed as either: |
|||
* 1. 12 field elements in the Miden VM. |
|||
* 2. 96 bytes. |
|||
* |
|||
* The first view is used for the internal state in the context of the RPO hashing algorithm. The |
|||
* second view is used for the buffer used to absorb the data to be hashed. |
|||
* |
|||
* The pointer to the buffer is updated as the data is absorbed. |
|||
* |
|||
* 'rpo128_context' must be initialized with rpo128_init() before first use. |
|||
*/ |
|||
typedef struct
{
    union
    {
        uint64_t A[12];   /* state viewed as 12 field elements */
        uint8_t dbuf[96]; /* the same state viewed as a 96-byte buffer */
    } st;
    uint64_t dptr;        /* current read/write offset into dbuf */
} rpo128_context;
|||
|
|||
/* |
|||
* Initializes an RPO state |
|||
*/ |
|||
void rpo128_init(rpo128_context *rc); |
|||
|
|||
/* |
|||
* Absorbs an array of bytes of length 'len' into the state. |
|||
*/ |
|||
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len); |
|||
|
|||
/* |
|||
* Squeezes an array of bytes of length 'len' from the state. |
|||
*/ |
|||
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len); |
|||
|
|||
/* |
|||
* Finalizes the state in preparation for squeezing. |
|||
* |
|||
* This function should be called after all the data has been absorbed. |
|||
* |
|||
* Note that the current implementation does not perform any sort of padding for domain separation |
|||
* purposes. The reason being that, for our purposes, we always perform the following sequence: |
|||
* 1. Absorb a Nonce (which is always 40 bytes packed as 8 field elements). |
|||
* 2. Absorb the message (which is always 4 field elements). |
|||
* 3. Call finalize. |
|||
* 4. Squeeze the output. |
|||
* 5. Call release. |
|||
*/ |
|||
void rpo128_finalize(rpo128_context *rc); |
|||
|
|||
/* |
|||
* Releases the state. |
|||
* |
|||
* This function should be called after the squeeze operation is finished. |
|||
*/ |
|||
void rpo128_release(rpo128_context *rc); |
|||
|
|||
/* ================================================================================================ |
|||
* Hash-to-Point algorithm for signature generation and signature verification. |
|||
*/ |
|||
|
|||
/* |
|||
* Hash-to-Point algorithm. |
|||
* |
|||
* This function generates a point in Z_q[x]/(phi) from a given message. |
|||
* |
|||
* It takes a finalized rpo128_context as input and it generates the coefficients of the polynomial |
|||
* representing the point. The coefficients are stored in the array 'x'. The number of coefficients |
|||
 * is 2^logn, which in our case must be 512 (i.e., logn = 9). |
|||
*/ |
|||
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn); |
@ -0,0 +1,189 @@ |
|||
use libc::c_int;
|
|||
|
|||
// C IMPLEMENTATION INTERFACE
|
|||
// ================================================================================================
|
|||
|
|||
#[link(name = "rpo_falcon512", kind = "static")]
extern "C" {
    /// Generate a new key pair. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;

    /// Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
        pk: *mut u8,
        sk: *mut u8,
        seed: *const u8,
    ) -> c_int;

    /// Compute a signature on a provided message (m, mlen), with a given private key (sk).
    /// Signature is written in sig[], with length written into *siglen. Signature length is
    /// variable; maximum signature length (in bytes) is 666.
    ///
    /// sig[], m[] and sk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
        sig: *mut u8,
        siglen: *mut usize,
        m: *const u8,
        mlen: usize,
        sk: *const u8,
    ) -> c_int;

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    /// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
    ///
    /// sig[], m[] and pk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
        sig: *const u8,
        siglen: usize,
        m: *const u8,
        mlen: usize,
        pk: *const u8,
    ) -> c_int;

    /// Hash-to-Point algorithm.
    ///
    /// This function generates a point in Z_q[x]/(phi) from a given message.
    ///
    /// It takes a finalized rpo128_context as input and it generates the coefficients of the
    /// polynomial representing the point. The coefficients are stored in the array 'x'. The
    /// number of coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
    ///
    /// NOTE(review): the C prototype declares `logn` as `unsigned` (typically 32-bit) while
    /// this binding uses `usize` (64-bit on most targets) — confirm the ABI match or switch
    /// to `libc::c_uint`.
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
        rc: *mut Rpo128Context,
        x: *mut u16,
        logn: usize,
    );

    /// Initializes the RPO128 sponge state (see rpo128_init in the C sources).
    #[cfg(test)]
    pub fn rpo128_init(sc: *mut Rpo128Context);

    /// Absorbs `len` bytes of `data` into the sponge state.
    #[cfg(test)]
    pub fn rpo128_absorb(
        sc: *mut Rpo128Context,
        data: *const ::std::os::raw::c_void,
        len: libc::size_t,
    );

    /// Finalizes absorption in preparation for squeezing.
    #[cfg(test)]
    pub fn rpo128_finalize(sc: *mut Rpo128Context);
}
|
|||
|
|||
/// Mirror of the C `rpo128_context` struct: a 96-byte state (12 u64 field
/// elements, viewable as bytes via a union on the C side) followed by the
/// u64 data pointer `dptr` — 13 x u64 in total.
#[repr(C)]
#[cfg(test)]
pub struct Rpo128Context {
    pub content: [u64; 13usize],
}
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
    use rand_utils::{rand_array, rand_value, rand_vector};

    /// End-to-end exercise of the C FFI: keygen from seed, signing, successful
    /// verification, and rejection of a tampered message and a mismatched key.
    #[test]
    fn falcon_ffi() {
        unsafe {
            // --- generate a key pair from a seed ----------------------------

            let mut pk = [0u8; PK_LEN];
            let mut sk = [0u8; SK_LEN];
            let seed: [u8; NONCE_LEN] = rand_array();

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                    pk.as_mut_ptr(),
                    sk.as_mut_ptr(),
                    seed.as_ptr()
                )
            );

            // --- sign a message and make sure it verifies -------------------

            // random message length in [0, u16::MAX]
            let mlen: usize = rand_value::<u16>() as usize;
            let msg: Vec<u8> = rand_vector(mlen);
            let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
            let mut siglen = 0;

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                    detached_sig.as_mut_ptr(),
                    &mut siglen as *mut usize,
                    msg.as_ptr(),
                    msg.len(),
                    sk.as_ptr()
                )
            );

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk.as_ptr()
                )
            );

            // --- check verification of different signature ------------------

            // truncating the message by one byte must make verification fail
            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len() - 1,
                    pk.as_ptr()
                )
            );

            // --- check verification against a different pub key -------------

            let mut pk_alt = [0u8; PK_LEN];
            let mut sk_alt = [0u8; SK_LEN];
            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                    pk_alt.as_mut_ptr(),
                    sk_alt.as_mut_ptr()
                )
            );

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk_alt.as_ptr()
                )
            );
        }
    }
}
|
@ -0,0 +1,227 @@ |
|||
use super::{
|
|||
ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
|
|||
PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
|
|||
};
|
|||
|
|||
#[cfg(feature = "std")]
|
|||
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
|
|||
|
|||
// PUBLIC KEY
|
|||
// ================================================================================================
|
|||
|
|||
/// A public key for verifying signatures.
|
|||
///
|
|||
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
|
|||
/// the polynomial representing the raw bytes of the expanded public key.
|
|||
///
|
|||
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512) i.e., 9.
|
|||
#[derive(Debug, Clone, Copy, PartialEq)]
|
|||
pub struct PublicKey(Word);
|
|||
|
|||
impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    ///
    /// The commitment is the Rpo256 hash of the field-element representation of the decoded
    /// public-key polynomial.
    ///
    /// # Errors
    /// Returns an error if the decoding of the public key fails.
    pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
        // decode the 897 raw bytes into a polynomial over Z_p[x]/(phi)
        let h = Polynomial::from_pub_key(&pk)?;
        let pk_felts = h.to_elements();
        let pk_digest = Rpo256::hash_elements(&pk_felts).into();
        Ok(Self(pk_digest))
    }

    /// Verifies the provided signature against provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}
|
|||
|
|||
impl From<PublicKey> for Word {
|
|||
fn from(key: PublicKey) -> Self {
|
|||
key.0
|
|||
}
|
|||
}
|
|||
|
|||
// KEY PAIR
|
|||
// ================================================================================================
|
|||
|
|||
/// A key pair (public and secret keys) for signing messages.
|
|||
///
|
|||
/// The public key is a byte array of length [PK_LEN].
/// The secret key is a byte array of length [SK_LEN].
|
|||
#[derive(Debug, Clone, Copy, PartialEq)]
|
|||
pub struct KeyPair {
|
|||
public_key: PublicKeyBytes,
|
|||
secret_key: SecretKeyBytes,
|
|||
}
|
|||
|
|||
#[allow(clippy::new_without_default)]
impl KeyPair {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a (public_key, secret_key) key pair from OS-provided randomness.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn new() -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        // SAFETY: the output buffers are exactly PK_LEN/SK_LEN bytes (the sizes
        // documented by the FFI declaration) and live for the whole call.
        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
            )
        };

        // the C function returns 0 on success, -1 on error
        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    /// Generates a (public_key, secret_key) key pair from the provided seed.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        // SAFETY: output buffers are correctly sized, and the seed is a
        // NONCE_LEN (40) byte array that outlives the call.
        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
                seed.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key corresponding to this key pair.
    pub fn public_key(&self) -> PublicKey {
        // TODO: memoize public key commitment as computing it requires quite a bit of hashing.
        // expect() is fine here because we assume that the key pair was constructed correctly.
        PublicKey::new(self.public_key).expect("invalid key pair")
    }

    /// Returns the expanded public key corresponding to this key pair.
    pub fn expanded_public_key(&self) -> PublicKeyBytes {
        self.public_key
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with a secret key and a seed.
    ///
    /// # Errors
    /// Returns an error if signature generation fails.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
        // serialize the message's field elements into little-endian bytes
        let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
        let msg_len = msg.len();
        // buffer large enough for the nonce plus the maximum signature size
        let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
        let mut sig_len: usize = 0;

        // SAFETY: `sig` can hold the maximum signature size and the message
        // pointer/length pair describes a valid, live byte slice.
        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                sig.as_mut_ptr(),
                &mut sig_len as *mut usize,
                msg.as_ptr(),
                msg_len,
                self.secret_key.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Signature { sig, pk: self.public_key })
        } else {
            Err(FalconError::SigGenerationFailed)
        }
    }
}
|
|||
|
|||
// SERIALIZATION / DESERIALIZATION
|
|||
// ================================================================================================
|
|||
|
|||
impl Serializable for KeyPair {
    /// Writes the expanded public key bytes followed by the secret key bytes.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.public_key);
        target.write_bytes(&self.secret_key);
    }
}
|
|||
|
|||
impl Deserializable for KeyPair {
    /// Reads a key pair back in the same order it is serialized:
    /// public key bytes first, then secret key bytes.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let public_key: PublicKeyBytes = source.read_array()?;
        let secret_key: SecretKeyBytes = source.read_array()?;
        Ok(Self { public_key, secret_key })
    }
}
|
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{super::Felt, KeyPair, NonceBytes, Word};
    use rand_utils::{rand_array, rand_vector};

    /// Shared test body (the two tests previously duplicated it verbatim):
    /// signs a random message with `keys` and checks that the signature
    /// verifies under the correct (message, key) pair and is rejected under
    /// a wrong message or a wrong public key.
    fn assert_sign_and_verify(keys: KeyPair) {
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }

    #[test]
    fn test_falcon_verification() {
        // generate random keys
        assert_sign_and_verify(KeyPair::new().unwrap());
    }

    #[test]
    fn test_falcon_verification_from_seed() {
        // generate keys from a random seed
        let seed: NonceBytes = rand_array();
        assert_sign_and_verify(KeyPair::from_seed(&seed).unwrap());
    }
}
|
@ -0,0 +1,60 @@ |
|||
use crate::{
|
|||
hash::rpo::Rpo256,
|
|||
utils::{
|
|||
collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
|
|||
Serializable,
|
|||
},
|
|||
Felt, StarkField, Word, ZERO,
|
|||
};
|
|||
|
|||
#[cfg(feature = "std")]
|
|||
mod ffi;
|
|||
|
|||
mod error;
|
|||
mod keys;
|
|||
mod polynomial;
|
|||
mod signature;
|
|||
|
|||
pub use error::FalconError;
|
|||
pub use keys::{KeyPair, PublicKey};
|
|||
pub use polynomial::Polynomial;
|
|||
pub use signature::Signature;
|
|||
|
|||
// CONSTANTS
|
|||
// ================================================================================================
|
|||
|
|||
// The Falcon modulus.
|
|||
const MODULUS: u16 = 12289;
|
|||
const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;
|
|||
|
|||
// The Falcon parameters for Falcon-512. This is the degree of the polynomial `phi := x^N + 1`
|
|||
// defining the ring Z_p[x]/(phi).
|
|||
const N: usize = 512;
|
|||
const LOG_N: usize = 9;
|
|||
|
|||
/// Length of nonce used for key-pair generation.
|
|||
const NONCE_LEN: usize = 40;
|
|||
|
|||
/// Number of field elements used to encode a nonce.
|
|||
const NONCE_ELEMENTS: usize = 8;
|
|||
|
|||
/// Public key length as a u8 vector.
|
|||
const PK_LEN: usize = 897;
|
|||
|
|||
/// Secret key length as a u8 vector.
|
|||
const SK_LEN: usize = 1281;
|
|||
|
|||
/// Signature length as a u8 vector.
|
|||
const SIG_LEN: usize = 626;
|
|||
|
|||
/// Bound on the squared-norm of the signature.
|
|||
const SIG_L2_BOUND: u64 = 34034726;
|
|||
|
|||
// TYPE ALIASES
|
|||
// ================================================================================================
|
|||
|
|||
type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
|
|||
type PublicKeyBytes = [u8; PK_LEN];
|
|||
type SecretKeyBytes = [u8; SK_LEN];
|
|||
type NonceBytes = [u8; NONCE_LEN];
|
|||
type NonceElements = [Felt; NONCE_ELEMENTS];
|
@ -0,0 +1,277 @@ |
|||
use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
|
|||
use core::ops::{Add, Mul, Sub};
|
|||
|
|||
// FALCON POLYNOMIAL
|
|||
// ================================================================================================
|
|||
|
|||
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
|
|||
#[derive(Debug, Copy, Clone, PartialEq)]
|
|||
pub struct Polynomial([u16; N]);
|
|||
|
|||
impl Polynomial {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Constructs a new polynomial from a list of coefficients.
    ///
    /// # Safety
    /// This constructor validates that the coefficients are in the valid range only in debug mode.
    pub unsafe fn new(data: [u16; N]) -> Self {
        for value in data {
            debug_assert!(value < MODULUS);
        }

        Self(data)
    }

    /// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
    ///
    /// # Errors
    /// Returns an error if:
    /// - The provided input is not exactly 897 bytes long.
    /// - The first byte of the input is not equal to log2(512) i.e., 9.
    /// - Any of the coefficients encoded in the provided input is greater than or equal to the
    ///   Falcon field modulus.
    pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
        if input.len() != PK_LEN {
            return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
        }

        if input[0] != LOG_N as u8 {
            return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
        }

        // coefficients are packed as consecutive 14-bit big-endian values; accumulate bytes
        // into `acc` and emit one coefficient whenever at least 14 bits are available
        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [0_u16; N];
        let mut output_idx = 0;

        for &byte in input.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= 14 {
                acc_len -= 14;
                let w = (acc >> acc_len) & 0x3FFF;
                if w >= MODULUS as u32 {
                    return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
                }
                output[output_idx] = w as u16;
                output_idx += 1;
            }
        }

        // any leftover padding bits in the last byte must be zero
        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Self(output))
        } else {
            Err(FalconError::PubKeyDecodingExtraData)
        }
    }

    /// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
    /// that the signature has been encoded using the compressed format: per coefficient, a sign
    /// bit, 7 low magnitude bits, and a unary-coded run of high bits terminated by a 1-bit.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The signature has been encoded using a different algorithm than the reference compressed
    ///   encoding algorithm.
    /// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
    /// - While decoding the high bits of a coefficient, the current accumulated value of its
    ///   high bits is larger than 2048.
    /// - The decoded coefficient is -0.
    /// - The remaining unused bits in the last byte of `input` are non-zero.
    ///
    /// NOTE(review): `input[0]`, `&input[41..]`, and `input[input_idx]` are indexed without
    /// length checks, so a truncated or malformed buffer panics rather than returning an error.
    /// Confirm all call sites pass a complete, fixed-size signature buffer.
    pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
        // the header byte is 0b0cc1nnnn: cc selects the encoding, nnnn is log2(N)
        let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);

        if encoding != 0b0011 {
            return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
        }
        if log_n != 0b1001 {
            return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
        }

        // skip the header byte and the 40-byte nonce
        let input = &input[41..];
        let mut input_idx = 0;
        let mut acc = 0u32;
        let mut acc_len = 0;
        let mut output = [0_u16; N];

        for e in output.iter_mut() {
            // pull one fresh byte: bit 7 is the sign, bits 0..=6 are the low magnitude bits
            acc = (acc << 8) | (input[input_idx] as u32);
            input_idx += 1;
            let b = acc >> acc_len;
            let s = b & 128;
            let mut m = b & 127;

            // high magnitude bits are unary-coded: each 0-bit adds 128, a 1-bit terminates
            loop {
                if acc_len == 0 {
                    acc = (acc << 8) | (input[input_idx] as u32);
                    input_idx += 1;
                    acc_len = 8;
                }
                acc_len -= 1;
                if ((acc >> acc_len) & 1) != 0 {
                    break;
                }
                m += 128;
                if m >= 2048 {
                    return Err(FalconError::SigDecodingTooBigHighBits(m));
                }
            }
            // "-0" is not a canonical encoding and must be rejected
            if s != 0 && m == 0 {
                return Err(FalconError::SigDecodingMinusZero);
            }

            // map negative magnitudes to their representatives in [0, p)
            *e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
        }

        // unused padding bits in the last consumed byte must be zero
        if (acc & ((1 << acc_len) - 1)) != 0 {
            return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
        }

        Ok(Self(output))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as integers.
    pub fn inner(&self) -> [u16; N] {
        self.0
    }

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.0.iter().map(|&a| Felt::from(a)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p[x]/(phi).
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        // schoolbook convolution; index i + j is the degree of the product term
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.0[i] as u64 * b.0[j] as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [0; N];
        for i in 0..N {
            // x^(N + i) = -x^i (mod phi), so the upper half is subtracted from the lower half
            let ai = a[N + i] % MODULUS as u64;
            let neg_ai = (MODULUS - ai as u16) % MODULUS;

            let bi = (a[i] % MODULUS as u64) as u16;
            c[i] = (neg_ai + bi) % MODULUS;
        }

        Self(c)
    }

    /// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
    /// coefficients to be in the interval (-p/2, p/2].
    pub fn sq_norm(&self) -> u64 {
        let mut res = 0;
        for e in self.0 {
            // coefficients above (p - 1) / 2 represent negative values; square the
            // absolute value p - e instead
            if e > MODULUS_MINUS_1_OVER_TWO {
                res += (MODULUS - e) as u64 * (MODULUS - e) as u64
            } else {
                res += e as u64 * e as u64
            }
        }
        res
    }
}
|
|||
|
|||
// Returns a polynomial representing the zero polynomial i.e. default element.
|
|||
impl Default for Polynomial {
|
|||
fn default() -> Self {
|
|||
Self([0_u16; N])
|
|||
}
|
|||
}
|
|||
|
|||
/// Multiplication over Z_p[x]/(phi)
|
|||
impl Mul for Polynomial {
|
|||
type Output = Self;
|
|||
|
|||
fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
|
|||
let mut result = [0_u16; N];
|
|||
for j in 0..N {
|
|||
for k in 0..N {
|
|||
let i = (j + k) % N;
|
|||
let a = self.0[j] as usize;
|
|||
let b = other.0[k] as usize;
|
|||
let q = MODULUS as usize;
|
|||
let mut prod = a * b % q;
|
|||
if (N - 1) < (j + k) {
|
|||
prod = (q - prod) % q;
|
|||
}
|
|||
result[i] = ((result[i] as usize + prod) % q) as u16;
|
|||
}
|
|||
}
|
|||
|
|||
Polynomial(result)
|
|||
}
|
|||
}
|
|||
|
|||
/// Addition over Z_p[x]/(phi)
|
|||
impl Add for Polynomial {
|
|||
type Output = Self;
|
|||
|
|||
fn add(self, other: Self) -> <Self as Add<Self>>::Output {
|
|||
let mut res = self;
|
|||
res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);
|
|||
|
|||
res
|
|||
}
|
|||
}
|
|||
|
|||
/// Subtraction over Z_p[x]/(phi)
|
|||
impl Sub for Polynomial {
|
|||
type Output = Self;
|
|||
|
|||
fn sub(self, other: Self) -> <Self as Add<Self>>::Output {
|
|||
let mut res = self;
|
|||
res.0
|
|||
.iter_mut()
|
|||
.zip(other.0.iter())
|
|||
.for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);
|
|||
|
|||
res
|
|||
}
|
|||
}
|
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(test)]
mod tests {
    use super::{Polynomial, N};

    /// Checks that multiplying directly in Z_p[x]/(phi) (the `Mul` impl) agrees with the
    /// two-step route: multiply over Z_p[x] without reduction, then reduce negacyclically
    /// modulo phi.
    ///
    /// NOTE(review): `rand_array` may produce coefficients >= MODULUS; the identity still
    /// holds because both sides reduce mod q, but such inputs violate the documented
    /// invariant of `Polynomial` — confirm this is intentional.
    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u16; N] = rand_utils::rand_array();
        let coef2: [u16; N] = rand_utils::rand_array();

        let poly1 = Polynomial(coef1);
        let poly2 = Polynomial(coef2);

        assert_eq!(
            poly1 * poly2,
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}
|
@ -0,0 +1,262 @@ |
|||
use super::{
|
|||
ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
|
|||
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
|
|||
SIG_L2_BOUND, ZERO,
|
|||
};
|
|||
use crate::utils::string::ToString;
|
|||
|
|||
// FALCON SIGNATURE
|
|||
// ================================================================================================
|
|||
|
|||
/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
/// - p := 12289
/// - phi := x^512 + 1
/// - s1 = c - s2 * h
/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point
///   of the message being signed.
///
/// The signature verifies if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
/// where |.| is the norm.
///
/// [Signature] also includes the extended public key which is serialized as:
/// 1. 1 byte representing the log2(512) i.e., 9.
/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above.
///
/// The actual signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
///    together with the degree of the irreducible polynomial phi.
///    The general format of this byte is 0b0cc1nnnn where:
///    a. cc is either 01 when the compressed encoding algorithm is used and 10 when the
///       uncompressed algorithm is used.
///    b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi.
///    The current implementation works always with cc equal to 0b01 and nnnn equal to 0b1001 and
///    thus the header byte is always equal to 0b00111001.
/// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
pub struct Signature {
    /// Expanded public key: log2(N) tag byte followed by the 14-bit-packed `h` polynomial.
    pub(super) pk: PublicKeyBytes,
    /// Raw signature: header byte, 40-byte nonce, and the compressed `s2` polynomial.
    pub(super) sig: SignatureBytes,
}
|
|||
|
|||
impl Signature {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key polynomial h.
    pub fn pub_key_poly(&self) -> Polynomial {
        // TODO: memoize
        // we assume that the signature was constructed with a valid public key, and thus
        // expect() is OK here.
        Polynomial::from_pub_key(&self.pk).expect("invalid public key")
    }

    /// Returns the nonce component of the signature represented as field elements.
    ///
    /// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
    /// of the nonce and interpreting them as field elements.
    pub fn nonce(&self) -> NonceElements {
        // sig[0] is the header byte; bytes 1..41 hold the 40-byte nonce.
        // we assume that the signature was constructed with a valid signature, and thus
        // expect() is OK here.
        let nonce = self.sig[1..41].try_into().expect("invalid signature");
        decode_nonce(nonce)
    }

    // Returns the polynomial representation of the signature in Z_p[x]/(phi).
    pub fn sig_poly(&self) -> Polynomial {
        // TODO: memoize
        // we assume that the signature was constructed with a valid signature, and thus
        // expect() is OK here. Note: the full buffer is passed — from_signature() skips
        // the header and nonce itself.
        Polynomial::from_signature(&self.sig).expect("invalid signature")
    }

    // HASH-TO-POINT
    // --------------------------------------------------------------------------------------------

    /// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
    pub fn hash_to_point(&self, message: Word) -> Polynomial {
        hash_to_point(message, &self.nonce())
    }

    // SIGNATURE VERIFICATION
    // --------------------------------------------------------------------------------------------
    /// Returns true if this signature is a valid signature for the specified message generated
    /// against key pair matching the specified public key commitment.
    pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
        // Make sure the expanded public key matches the provided public key commitment
        let h = self.pub_key_poly();
        let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into();
        if h_digest != pubkey_com {
            return false;
        }

        // Make sure the signature is valid: recover s1 = c - s2 * h, then check that the
        // combined squared norm of (s1, s2) is within the Falcon bound
        let s2 = self.sig_poly();
        let c = self.hash_to_point(message);

        let s1 = c - s2 * h;

        let sq_norm = s1.sq_norm() + s2.sq_norm();
        sq_norm <= SIG_L2_BOUND
    }
}
|
|||
|
|||
// SERIALIZATION / DESERIALIZATION
|
|||
// ================================================================================================
|
|||
|
|||
impl Serializable for Signature {
    /// Writes the expanded public key bytes followed by the raw signature bytes
    /// (header + nonce + encoded s2) into `target`.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.pk);
        target.write_bytes(&self.sig);
    }
}
|
|||
|
|||
impl Deserializable for Signature {
|
|||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
|||
let pk: PublicKeyBytes = source.read_array()?;
|
|||
let sig: SignatureBytes = source.read_array()?;
|
|||
|
|||
// make sure public key and signature can be decoded correctly
|
|||
Polynomial::from_pub_key(&pk)
|
|||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
|||
Polynomial::from_signature(&sig[41..])
|
|||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
|||
|
|||
Ok(Self { pk, sig })
|
|||
}
|
|||
}
|
|||
|
|||
// HELPER FUNCTIONS
|
|||
// ================================================================================================
|
|||
|
|||
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce.
fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial {
    let mut state = [ZERO; Rpo256::STATE_WIDTH];

    // absorb the nonce into the state
    for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = n;
    }
    Rpo256::apply_permutation(&mut state);

    // absorb message into the state (permuted by the first squeeze iteration below)
    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = m;
    }

    // squeeze the coefficients of the polynomial; assumes the rate is 8 elements so that
    // 64 permutations yield exactly N = 512 coefficients — TODO confirm RATE_RANGE width
    let mut i = 0;
    let mut res = [0_u16; N];
    for _ in 0..64 {
        Rpo256::apply_permutation(&mut state);
        for a in &state[Rpo256::RATE_RANGE] {
            res[i] = (a.as_int() % MODULUS as u64) as u16;
            i += 1;
        }
    }

    // using the raw constructor is OK here because we reduce all coefficients by the modulus above
    unsafe { Polynomial::new(res) }
}
|
|||
|
|||
/// Converts byte representation of the nonce into field element representation.
|
|||
fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
|||
let mut buffer = [0_u8; 8];
|
|||
let mut result = [ZERO; 8];
|
|||
for (i, bytes) in nonce.chunks(5).enumerate() {
|
|||
buffer[..5].copy_from_slice(bytes);
|
|||
result[i] = u64::from_le_bytes(buffer).into();
|
|||
}
|
|||
|
|||
result
|
|||
}
|
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{
        super::{ffi::*, Felt},
        *,
    };
    use libc::c_void;
    use rand_utils::rand_vector;

    // Wrappers for unsafe functions
    impl Rpo128Context {
        /// Initializes the RPO state.
        pub fn init() -> Self {
            let mut ctx = Rpo128Context { content: [0u64; 13] };
            unsafe {
                rpo128_init(&mut ctx as *mut Rpo128Context);
            }
            ctx
        }

        /// Absorbs data into the RPO state.
        pub fn absorb(&mut self, data: &[u8]) {
            unsafe {
                rpo128_absorb(
                    self as *mut Rpo128Context,
                    data.as_ptr() as *const c_void,
                    data.len(),
                )
            }
        }

        /// Finalizes the RPO state to prepare for squeezing.
        pub fn finalize(&mut self) {
            unsafe { rpo128_finalize(self as *mut Rpo128Context) }
        }
    }

    /// Checks that the Rust `hash_to_point` agrees with the reference C implementation
    /// (`PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo`) on a random message and nonce.
    #[test]
    fn test_hash_to_point() {
        // Create a random message and transform it into a u8 vector
        let msg_felts: Word = rand_vector::<Felt>(4).try_into().unwrap();
        let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();

        // Create a nonce i.e. a [u8; 40] array and pack into a [Felt; 8] array.
        let nonce: [u8; 40] = rand_vector::<u8>(40).try_into().unwrap();

        // lay each 5-byte nonce chunk out as a zero-padded little-endian u64, mirroring
        // what `decode_nonce` does on the Rust side
        let mut buffer = [0_u8; 64];
        for i in 0..8 {
            buffer[8 * i] = nonce[5 * i];
            buffer[8 * i + 1] = nonce[5 * i + 1];
            buffer[8 * i + 2] = nonce[5 * i + 2];
            buffer[8 * i + 3] = nonce[5 * i + 3];
            buffer[8 * i + 4] = nonce[5 * i + 4];
        }

        // Initialize the RPO state
        let mut rng = Rpo128Context::init();

        // Absorb the nonce and message into the RPO state
        rng.absorb(&buffer);
        rng.absorb(&msg_bytes);
        rng.finalize();

        // Generate the coefficients of the hash-to-point polynomial.
        let mut res: [u16; N] = [0; N];

        unsafe {
            PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
                &mut rng as *mut Rpo128Context,
                res.as_mut_ptr(),
                9,
            );
        }

        // Check that the coefficients are correct
        let nonce = decode_nonce(&nonce);
        assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
    }
}
|
@ -0,0 +1,129 @@ |
|||
use clap::Parser;
|
|||
use miden_crypto::{
|
|||
hash::rpo::RpoDigest,
|
|||
merkle::MerkleError,
|
|||
Felt, Word, ONE,
|
|||
{hash::rpo::Rpo256, merkle::TieredSmt},
|
|||
};
|
|||
use rand_utils::rand_value;
|
|||
use std::time::Instant;
|
|||
|
|||
/// Command-line arguments for the TSMT benchmark binary, parsed with `clap`.
#[derive(Parser, Debug)]
#[clap(
    name = "Benchmark",
    about = "Tiered SMT benchmark",
    version,
    rename_all = "kebab-case"
)]
pub struct BenchmarkCmd {
    /// Size of the tree
    #[clap(short = 's', long = "size")]
    size: u64,
}
|
|||
|
|||
/// Entry point; CLI argument parsing happens inside `benchmark_tsmt`.
fn main() {
    benchmark_tsmt();
}
|
|||
|
|||
/// Run a benchmark for the Tiered SMT.
|
|||
pub fn benchmark_tsmt() {
|
|||
let args = BenchmarkCmd::parse();
|
|||
let tree_size = args.size;
|
|||
|
|||
// prepare the `leaves` vector for tree creation
|
|||
let mut entries = Vec::new();
|
|||
for i in 0..tree_size {
|
|||
let key = rand_value::<RpoDigest>();
|
|||
let value = [ONE, ONE, ONE, Felt::new(i)];
|
|||
entries.push((key, value));
|
|||
}
|
|||
|
|||
let mut tree = construction(entries, tree_size).unwrap();
|
|||
insertion(&mut tree, tree_size).unwrap();
|
|||
proof_generation(&mut tree, tree_size).unwrap();
|
|||
}
|
|||
|
|||
/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
|
|||
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
|
|||
println!("Running a construction benchmark:");
|
|||
let now = Instant::now();
|
|||
let tree = TieredSmt::with_entries(entries)?;
|
|||
let elapsed = now.elapsed();
|
|||
println!(
|
|||
"Constructed a TSMT with {} key-value pairs in {:.3} seconds",
|
|||
size,
|
|||
elapsed.as_secs_f32(),
|
|||
);
|
|||
|
|||
// Count how many nodes end up at each tier
|
|||
let mut nodes_num_16_32_48 = (0, 0, 0);
|
|||
|
|||
tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
|
|||
16 => nodes_num_16_32_48.0 += 1,
|
|||
32 => nodes_num_16_32_48.1 += 1,
|
|||
48 => nodes_num_16_32_48.2 += 1,
|
|||
_ => unreachable!(),
|
|||
});
|
|||
|
|||
println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
|
|||
println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
|
|||
println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
|
|||
println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
|
|||
|
|||
Ok(tree)
|
|||
}
|
|||
|
|||
/// Runs the insertion benchmark for the Tiered SMT.
|
|||
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
|||
println!("Running an insertion benchmark:");
|
|||
|
|||
let mut insertion_times = Vec::new();
|
|||
|
|||
for i in 0..20 {
|
|||
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
|||
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
|
|||
|
|||
let now = Instant::now();
|
|||
tree.insert(test_key, test_value);
|
|||
let elapsed = now.elapsed();
|
|||
insertion_times.push(elapsed.as_secs_f32());
|
|||
}
|
|||
|
|||
println!(
|
|||
"An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
|
|||
size,
|
|||
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
|
|||
// 1000. As a result, we can only multiply by 50
|
|||
insertion_times.iter().sum::<f32>() * 50f32,
|
|||
);
|
|||
|
|||
Ok(())
|
|||
}
|
|||
|
|||
/// Runs the proof generation benchmark for the Tiered SMT.
|
|||
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
|||
println!("Running a proof generation benchmark:");
|
|||
|
|||
let mut insertion_times = Vec::new();
|
|||
|
|||
for i in 0..20 {
|
|||
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
|||
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
|
|||
tree.insert(test_key, test_value);
|
|||
|
|||
let now = Instant::now();
|
|||
let _proof = tree.prove(test_key);
|
|||
let elapsed = now.elapsed();
|
|||
insertion_times.push(elapsed.as_secs_f32());
|
|||
}
|
|||
|
|||
println!(
|
|||
"An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
|
|||
size,
|
|||
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
|
|||
// 1000000. As a result, we can only multiply by 50000
|
|||
insertion_times.iter().sum::<f32>() * 50000f32,
|
|||
);
|
|||
|
|||
Ok(())
|
|||
}
|
@ -0,0 +1,156 @@ |
|||
use super::{
|
|||
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
|
|||
};
|
|||
use crate::utils::collections::Diff;
|
|||
|
|||
#[cfg(test)]
|
|||
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
|
|||
|
|||
// MERKLE STORE DELTA
|
|||
// ================================================================================================
|
|||
|
|||
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
|
|||
|
|||
// MERKLE TREE DELTA
|
|||
// ================================================================================================
|
|||
|
|||
/// [MerkleTreeDelta] stores the differences between the initial and final Merkle tree states.
///
/// The differences are represented as follows:
/// - depth: the depth of the merkle tree.
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
// In non-test builds the fields are private and accessed through the methods below; the
// `#[cfg(test)]` twin of this struct (further down) exposes them for test construction.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
    depth: u8,
    cleared_slots: Vec<u64>,
    updated_slots: Vec<(u64, Word)>,
}
|
|||
|
|||
impl MerkleTreeDelta {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns an empty delta for a Merkle tree of the specified depth.
    pub fn new(depth: u8) -> Self {
        Self {
            depth,
            cleared_slots: Vec::new(),
            updated_slots: Vec::new(),
        }
    }

    // ACCESSORS
    // --------------------------------------------------------------------------------------------
    /// Returns the depth of the Merkle tree the [MerkleTreeDelta] is associated with.
    pub fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns the indexes of slots where values were set to [ZERO; 4].
    pub fn cleared_slots(&self) -> &[u64] {
        &self.cleared_slots
    }

    /// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
    pub fn updated_slots(&self) -> &[(u64, Word)] {
        &self.updated_slots
    }

    // MODIFIERS
    // --------------------------------------------------------------------------------------------
    /// Adds a slot index to the list of cleared slots.
    pub fn add_cleared_slot(&mut self, index: u64) {
        self.cleared_slots.push(index);
    }

    /// Adds a slot index and a value to the list of updated slots.
    pub fn add_updated_slot(&mut self, index: u64, value: Word) {
        self.updated_slots.push((index, value));
    }
}
|
|||
|
|||
/// Extracts a [MerkleTreeDelta] object by comparing the leaves of two Merkle trees specified
/// by their roots and depth.
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
    tree_root_1: RpoDigest,
    tree_root_2: RpoDigest,
    depth: u8,
    merkle_store: &MerkleStore<T>,
) -> Result<MerkleTreeDelta, MerkleError> {
    // identical roots imply identical trees: return an empty delta without walking leaves
    if tree_root_1 == tree_root_2 {
        return Ok(MerkleTreeDelta::new(depth));
    }

    // collect the non-empty leaves of both trees and diff them (tree 1 -> tree 2)
    let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
        merkle_store.non_empty_leaves(tree_root_1, depth).collect();
    let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
        merkle_store.non_empty_leaves(tree_root_2, depth).collect();
    let diff = tree_1_leaves.diff(&tree_2_leaves);

    // TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
    Ok(MerkleTreeDelta {
        depth,
        cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
        updated_slots: diff
            .updated
            .into_iter()
            .map(|(index, leaf)| (index.value(), *leaf))
            .collect(),
    })
}
|
|||
|
|||
// INTERNALS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
// Test-build twin of [MerkleTreeDelta] with public fields so tests can construct expected
// deltas directly; keep its shape in sync with the `#[cfg(not(test))]` definition above.
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
    pub depth: u8,
    pub cleared_slots: Vec<u64>,
    pub updated_slots: Vec<(u64, Word)>,
}
|
|||
|
|||
// MERKLE DELTA
|
|||
// ================================================================================================
|
|||
/// Builds a SimpleSmt, applies one addition, one update, and one removal through a
/// MerkleStore, and checks that `merkle_tree_delta` reports exactly those changes.
#[test]
fn test_compute_merkle_delta() {
    let entries = vec![
        (10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
        (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
        (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
        (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
    ];
    let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
    let mut store: MerkleStore = (&simple_smt).into();
    let root = simple_smt.root();

    // add a new node
    let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
    let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
    let root = store.set_node(root, new_index, new_value.into()).unwrap().root;

    // update an existing node
    let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
    let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
    let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;

    // remove a node (setting it to the empty word clears the slot)
    let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
    let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;

    // diff the original root against the final root
    let merkle_delta =
        merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
    let expected_merkle_delta = MerkleTreeDelta {
        depth: simple_smt.depth(),
        cleared_slots: vec![remove_idx.value()],
        updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
    };

    assert_eq!(merkle_delta, expected_merkle_delta);
}
|
@ -0,0 +1,54 @@ |
|||
use crate::{
|
|||
merkle::{MerklePath, NodeIndex, RpoDigest},
|
|||
utils::collections::Vec,
|
|||
};
|
|||
use core::fmt;
|
|||
|
|||
/// Errors produced by Merkle tree, store, and path-set operations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MerkleError {
    ConflictingRoots(Vec<RpoDigest>),
    DepthTooSmall(u8),
    DepthTooBig(u64),
    DuplicateValuesForIndex(u64),
    DuplicateValuesForKey(RpoDigest),
    InvalidIndex { depth: u8, value: u64 },
    InvalidDepth { expected: u8, provided: u8 },
    InvalidPath(MerklePath),
    // (max, provided) — see the Display impl for the message wording
    InvalidNumEntries(usize, usize),
    NodeNotInSet(NodeIndex),
    NodeNotInStore(RpoDigest, NodeIndex),
    NumLeavesNotPowerOfTwo(usize),
    RootNotInStore(RpoDigest),
}
|
|||
|
|||
impl fmt::Display for MerkleError {
|
|||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|||
use MerkleError::*;
|
|||
match self {
|
|||
ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
|
|||
DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
|
|||
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
|
|||
DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
|
|||
DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
|
|||
InvalidIndex{ depth, value} => write!(
|
|||
f,
|
|||
"the index value {value} is not valid for the depth {depth}"
|
|||
),
|
|||
InvalidDepth { expected, provided } => write!(
|
|||
f,
|
|||
"the provided depth {provided} is not valid for {expected}"
|
|||
),
|
|||
InvalidPath(_path) => write!(f, "the provided path is not valid"),
|
|||
InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
|
|||
NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
|
|||
NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
|
|||
NumLeavesNotPowerOfTwo(leaves) => {
|
|||
write!(f, "the leaves count {leaves} is not a power of 2")
|
|||
}
|
|||
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
// Marker impl: `std::error::Error` has no required methods here; gated on `std` since the
// trait comes from the standard library rather than `core`.
#[cfg(feature = "std")]
impl std::error::Error for MerkleError {}
|
@ -1,408 +0,0 @@ |
|||
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, ValuePath, Vec};
|
|||
use crate::{hash::rpo::RpoDigest, Word};
|
|||
|
|||
// MERKLE PATH SET
|
|||
// ================================================================================================
|
|||
|
|||
/// A set of Merkle paths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerklePathSet {
    // root that every stored path resolves to
    root: RpoDigest,
    // depth of the implied Merkle tree; all paths are for leaves at this depth
    total_depth: u8,
    // paths keyed by the even member of each leaf-index sibling pair (index & !1)
    paths: BTreeMap<u64, MerklePath>,
}
|
|||
|
|||
impl MerklePathSet {
|
|||
// CONSTRUCTOR
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns an empty MerklePathSet.
|
|||
pub fn new(depth: u8) -> Self {
|
|||
let root = RpoDigest::default();
|
|||
let paths = BTreeMap::new();
|
|||
|
|||
Self {
|
|||
root,
|
|||
total_depth: depth,
|
|||
paths,
|
|||
}
|
|||
}
|
|||
|
|||
/// Appends the provided paths iterator into the set.
|
|||
///
|
|||
/// Analogous to `[Self::add_path]`.
|
|||
pub fn with_paths<I>(self, paths: I) -> Result<Self, MerkleError>
|
|||
where
|
|||
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
|
|||
{
|
|||
paths.into_iter().try_fold(self, |mut set, (index, value, path)| {
|
|||
set.add_path(index, value.into(), path)?;
|
|||
Ok(set)
|
|||
})
|
|||
}
|
|||
|
|||
// PUBLIC ACCESSORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns the root to which all paths in this set resolve.
|
|||
pub const fn root(&self) -> RpoDigest {
|
|||
self.root
|
|||
}
|
|||
|
|||
/// Returns the depth of the Merkle tree implied by the paths stored in this set.
|
|||
///
|
|||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
|||
pub const fn depth(&self) -> u8 {
|
|||
self.total_depth
|
|||
}
|
|||
|
|||
/// Returns a node at the specified index.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// * The specified index is not valid for the depth of structure.
|
|||
/// * Requested node does not exist in the set.
|
|||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
|||
if index.depth() != self.total_depth {
|
|||
return Err(MerkleError::InvalidDepth {
|
|||
expected: self.total_depth,
|
|||
provided: index.depth(),
|
|||
});
|
|||
}
|
|||
|
|||
let parity = index.value() & 1;
|
|||
let path_key = index.value() - parity;
|
|||
self.paths
|
|||
.get(&path_key)
|
|||
.ok_or(MerkleError::NodeNotInSet(index))
|
|||
.map(|path| path[parity as usize])
|
|||
}
|
|||
|
|||
/// Returns a leaf at the specified index.
|
|||
///
|
|||
/// # Errors
|
|||
/// * The specified index is not valid for the depth of the structure.
|
|||
/// * Leaf with the requested path does not exist in the set.
|
|||
pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
|
|||
let index = NodeIndex::new(self.depth(), index)?;
|
|||
Ok(self.get_node(index)?.into())
|
|||
}
|
|||
|
|||
/// Returns a Merkle path to the node at the specified index. The node itself is
|
|||
/// not included in the path.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// * The specified index is not valid for the depth of structure.
|
|||
/// * Node of the requested path does not exist in the set.
|
|||
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
|||
if index.depth() != self.total_depth {
|
|||
return Err(MerkleError::InvalidDepth {
|
|||
expected: self.total_depth,
|
|||
provided: index.depth(),
|
|||
});
|
|||
}
|
|||
|
|||
let parity = index.value() & 1;
|
|||
let path_key = index.value() - parity;
|
|||
let mut path =
|
|||
self.paths.get(&path_key).cloned().ok_or(MerkleError::NodeNotInSet(index))?;
|
|||
path.remove(parity as usize);
|
|||
Ok(path)
|
|||
}
|
|||
|
|||
/// Returns all paths in this path set together with their indexes.
|
|||
pub fn to_paths(&self) -> Vec<(u64, ValuePath)> {
|
|||
let mut result = Vec::with_capacity(self.paths.len() * 2);
|
|||
|
|||
for (&index, path) in self.paths.iter() {
|
|||
// push path for the even index into the result
|
|||
let path1 = ValuePath {
|
|||
value: path[0],
|
|||
path: MerklePath::new(path[1..].to_vec()),
|
|||
};
|
|||
result.push((index, path1));
|
|||
|
|||
// push path for the odd index into the result
|
|||
let mut path2 = path.clone();
|
|||
let leaf2 = path2.remove(1);
|
|||
let path2 = ValuePath {
|
|||
value: leaf2,
|
|||
path: path2,
|
|||
};
|
|||
result.push((index + 1, path2));
|
|||
}
|
|||
|
|||
result
|
|||
}
|
|||
|
|||
// STATE MUTATORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters
|
|||
/// specify the leaf node at which the path starts.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// - The specified index is is not valid in the context of this Merkle path set (i.e., the
|
|||
/// index implies a greater depth than is specified for this set).
|
|||
/// - The specified path is not consistent with other paths in the set (i.e., resolves to a
|
|||
/// different root).
|
|||
pub fn add_path(
|
|||
&mut self,
|
|||
index_value: u64,
|
|||
value: Word,
|
|||
mut path: MerklePath,
|
|||
) -> Result<(), MerkleError> {
|
|||
let mut index = NodeIndex::new(path.len() as u8, index_value)?;
|
|||
if index.depth() != self.total_depth {
|
|||
return Err(MerkleError::InvalidDepth {
|
|||
expected: self.total_depth,
|
|||
provided: index.depth(),
|
|||
});
|
|||
}
|
|||
|
|||
// update the current path
|
|||
let parity = index_value & 1;
|
|||
path.insert(parity as usize, value.into());
|
|||
|
|||
// traverse to the root, updating the nodes
|
|||
let root = Rpo256::merge(&[path[0], path[1]]);
|
|||
let root = path.iter().skip(2).copied().fold(root, |root, hash| {
|
|||
index.move_up();
|
|||
Rpo256::merge(&index.build_node(root, hash))
|
|||
});
|
|||
|
|||
// if the path set is empty (the root is all ZEROs), set the root to the root of the added
|
|||
// path; otherwise, the root of the added path must be identical to the current root
|
|||
if self.root == RpoDigest::default() {
|
|||
self.root = root;
|
|||
} else if self.root != root {
|
|||
return Err(MerkleError::ConflictingRoots([self.root, root].to_vec()));
|
|||
}
|
|||
|
|||
// finish updating the path
|
|||
let path_key = index_value - parity;
|
|||
self.paths.insert(path_key, path);
|
|||
Ok(())
|
|||
}
|
|||
|
|||
/// Replaces the leaf at the specified index with the provided value.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// * Requested node does not exist in the set.
|
|||
pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> {
|
|||
let mut index = NodeIndex::new(self.depth(), base_index_value)?;
|
|||
let parity = index.value() & 1;
|
|||
let path_key = index.value() - parity;
|
|||
let path = match self.paths.get_mut(&path_key) {
|
|||
Some(path) => path,
|
|||
None => return Err(MerkleError::NodeNotInSet(index)),
|
|||
};
|
|||
|
|||
// Fill old_hashes vector -----------------------------------------------------------------
|
|||
let mut current_index = index;
|
|||
let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
|||
let mut root = Rpo256::merge(&[path[0], path[1]]);
|
|||
for hash in path.iter().skip(2).copied() {
|
|||
old_hashes.push(root);
|
|||
current_index.move_up();
|
|||
let input = current_index.build_node(hash, root);
|
|||
root = Rpo256::merge(&input);
|
|||
}
|
|||
|
|||
// Fill new_hashes vector -----------------------------------------------------------------
|
|||
path[index.is_value_odd() as usize] = value.into();
|
|||
|
|||
let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
|||
let mut new_root = Rpo256::merge(&[path[0], path[1]]);
|
|||
for path_hash in path.iter().skip(2).copied() {
|
|||
new_hashes.push(new_root);
|
|||
index.move_up();
|
|||
let input = current_index.build_node(path_hash, new_root);
|
|||
new_root = Rpo256::merge(&input);
|
|||
}
|
|||
|
|||
self.root = new_root;
|
|||
|
|||
// update paths ---------------------------------------------------------------------------
|
|||
for path in self.paths.values_mut() {
|
|||
for i in (0..old_hashes.len()).rev() {
|
|||
if path[i + 2] == old_hashes[i] {
|
|||
path[i + 2] = new_hashes[i];
|
|||
break;
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
Ok(())
|
|||
}
|
|||
}
|
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(test)]
|
|||
mod tests {
|
|||
use super::*;
|
|||
use crate::merkle::{int_to_leaf, int_to_node};
|
|||
|
|||
#[test]
|
|||
fn get_root() {
|
|||
let leaf0 = int_to_node(0);
|
|||
let leaf1 = int_to_node(1);
|
|||
let leaf2 = int_to_node(2);
|
|||
let leaf3 = int_to_node(3);
|
|||
|
|||
let parent0 = calculate_parent_hash(leaf0, 0, leaf1);
|
|||
let parent1 = calculate_parent_hash(leaf2, 2, leaf3);
|
|||
|
|||
let root_exp = calculate_parent_hash(parent0, 0, parent1);
|
|||
|
|||
let set = super::MerklePathSet::new(2)
|
|||
.with_paths([(0, leaf0, vec![leaf1, parent1].into())])
|
|||
.unwrap();
|
|||
|
|||
assert_eq!(set.root(), root_exp);
|
|||
}
|
|||
|
|||
#[test]
|
|||
fn add_and_get_path() {
|
|||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
|||
let hash_6 = int_to_node(6);
|
|||
let index = 6_u64;
|
|||
let depth = 3_u8;
|
|||
let set = super::MerklePathSet::new(depth)
|
|||
.with_paths([(index, hash_6, path_6.clone().into())])
|
|||
.unwrap();
|
|||
let stored_path_6 = set.get_path(NodeIndex::make(depth, index)).unwrap();
|
|||
|
|||
assert_eq!(path_6, *stored_path_6);
|
|||
}
|
|||
|
|||
#[test]
|
|||
fn get_node() {
|
|||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
|||
let hash_6 = int_to_node(6);
|
|||
let index = 6_u64;
|
|||
let depth = 3_u8;
|
|||
let set = MerklePathSet::new(depth).with_paths([(index, hash_6, path_6.into())]).unwrap();
|
|||
|
|||
assert_eq!(int_to_node(6u64), set.get_node(NodeIndex::make(depth, index)).unwrap());
|
|||
}
|
|||
|
|||
#[test]
|
|||
fn update_leaf() {
|
|||
let hash_4 = int_to_node(4);
|
|||
let hash_5 = int_to_node(5);
|
|||
let hash_6 = int_to_node(6);
|
|||
let hash_7 = int_to_node(7);
|
|||
let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5);
|
|||
let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7);
|
|||
|
|||
let hash_0123 = int_to_node(123);
|
|||
|
|||
let path_6 = vec![hash_7, hash_45, hash_0123];
|
|||
let path_5 = vec![hash_4, hash_67, hash_0123];
|
|||
let path_4 = vec![hash_5, hash_67, hash_0123];
|
|||
|
|||
let index_6 = 6_u64;
|
|||
let index_5 = 5_u64;
|
|||
let index_4 = 4_u64;
|
|||
let depth = 3_u8;
|
|||
let mut set = MerklePathSet::new(depth)
|
|||
.with_paths([
|
|||
(index_6, hash_6, path_6.into()),
|
|||
(index_5, hash_5, path_5.into()),
|
|||
(index_4, hash_4, path_4.into()),
|
|||
])
|
|||
.unwrap();
|
|||
|
|||
let new_hash_6 = int_to_leaf(100);
|
|||
let new_hash_5 = int_to_leaf(55);
|
|||
|
|||
set.update_leaf(index_6, new_hash_6).unwrap();
|
|||
let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
|
|||
let new_hash_67 = calculate_parent_hash(new_hash_6.into(), 14_u64, hash_7);
|
|||
assert_eq!(new_hash_67, new_path_4[1]);
|
|||
|
|||
set.update_leaf(index_5, new_hash_5).unwrap();
|
|||
let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
|
|||
let new_path_6 = set.get_path(NodeIndex::make(depth, index_6)).unwrap();
|
|||
let new_hash_45 = calculate_parent_hash(new_hash_5.into(), 13_u64, hash_4);
|
|||
assert_eq!(new_hash_45, new_path_6[1]);
|
|||
assert_eq!(RpoDigest::from(new_hash_5), new_path_4[0]);
|
|||
}
|
|||
|
|||
#[test]
|
|||
fn depth_3_is_correct() {
|
|||
let a = int_to_node(1);
|
|||
let b = int_to_node(2);
|
|||
let c = int_to_node(3);
|
|||
let d = int_to_node(4);
|
|||
let e = int_to_node(5);
|
|||
let f = int_to_node(6);
|
|||
let g = int_to_node(7);
|
|||
let h = int_to_node(8);
|
|||
|
|||
let i = Rpo256::merge(&[a, b]);
|
|||
let j = Rpo256::merge(&[c, d]);
|
|||
let k = Rpo256::merge(&[e, f]);
|
|||
let l = Rpo256::merge(&[g, h]);
|
|||
|
|||
let m = Rpo256::merge(&[i, j]);
|
|||
let n = Rpo256::merge(&[k, l]);
|
|||
|
|||
let root = Rpo256::merge(&[m, n]);
|
|||
|
|||
let mut set = MerklePathSet::new(3);
|
|||
|
|||
let value = b;
|
|||
let index = 1;
|
|||
let path = MerklePath::new([a, j, n].to_vec());
|
|||
set.add_path(index, value.into(), path).unwrap();
|
|||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|||
assert_eq!(root, set.root());
|
|||
|
|||
let value = e;
|
|||
let index = 4;
|
|||
let path = MerklePath::new([f, l, m].to_vec());
|
|||
set.add_path(index, value.into(), path).unwrap();
|
|||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|||
assert_eq!(root, set.root());
|
|||
|
|||
let value = a;
|
|||
let index = 0;
|
|||
let path = MerklePath::new([b, j, n].to_vec());
|
|||
set.add_path(index, value.into(), path).unwrap();
|
|||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|||
assert_eq!(root, set.root());
|
|||
|
|||
let value = h;
|
|||
let index = 7;
|
|||
let path = MerklePath::new([g, k, m].to_vec());
|
|||
set.add_path(index, value.into(), path).unwrap();
|
|||
assert_eq!(*value, set.get_leaf(index).unwrap());
|
|||
assert_eq!(root, set.root());
|
|||
}
|
|||
|
|||
// HELPER FUNCTIONS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
const fn is_even(pos: u64) -> bool {
|
|||
pos & 1 == 0
|
|||
}
|
|||
|
|||
/// Calculates the hash of the parent node by two sibling ones
|
|||
/// - node — current node
|
|||
/// - node_pos — position of the current node
|
|||
/// - sibling — neighboring vertex in the tree
|
|||
fn calculate_parent_hash(node: RpoDigest, node_pos: u64, sibling: RpoDigest) -> RpoDigest {
|
|||
if is_even(node_pos) {
|
|||
Rpo256::merge(&[node, sibling])
|
|||
} else {
|
|||
Rpo256::merge(&[sibling, node])
|
|||
}
|
|||
}
|
|||
}
|
@ -0,0 +1,48 @@ |
|||
use core::fmt::Display;
|
|||
|
|||
#[derive(Debug, PartialEq, Eq)]
|
|||
pub enum TieredSmtProofError {
|
|||
EntriesEmpty,
|
|||
EmptyValueNotAllowed,
|
|||
MismatchedPrefixes(u64, u64),
|
|||
MultipleEntriesOutsideLastTier,
|
|||
NotATierPath(u8),
|
|||
PathTooLong,
|
|||
}
|
|||
|
|||
impl Display for TieredSmtProofError {
|
|||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|||
match self {
|
|||
TieredSmtProofError::EntriesEmpty => {
|
|||
write!(f, "Missing entries for tiered sparse merkle tree proof")
|
|||
}
|
|||
TieredSmtProofError::EmptyValueNotAllowed => {
|
|||
write!(
|
|||
f,
|
|||
"The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
|
|||
)
|
|||
}
|
|||
TieredSmtProofError::MismatchedPrefixes(first, second) => {
|
|||
write!(f, "Not all leaves have the same prefix. First {first} second {second}")
|
|||
}
|
|||
TieredSmtProofError::MultipleEntriesOutsideLastTier => {
|
|||
write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
|
|||
}
|
|||
TieredSmtProofError::NotATierPath(got) => {
|
|||
write!(
|
|||
f,
|
|||
"Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
|
|||
)
|
|||
}
|
|||
TieredSmtProofError::PathTooLong => {
|
|||
write!(
|
|||
f,
|
|||
"Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
|
|||
)
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
#[cfg(feature = "std")]
|
|||
impl std::error::Error for TieredSmtProofError {}
|
@ -0,0 +1,419 @@ |
|||
use super::{
|
|||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
|
|||
NodeIndex, Rpo256, RpoDigest, Vec,
|
|||
};
|
|||
|
|||
// CONSTANTS
|
|||
// ================================================================================================
|
|||
|
|||
/// The number of levels between tiers.
|
|||
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;
|
|||
|
|||
/// Depths at which leaves can exist in a tiered SMT.
|
|||
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
|||
|
|||
/// Maximum node depth. This is also the bottom tier of the tree.
|
|||
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
|||
|
|||
// NODE STORE
|
|||
// ================================================================================================
|
|||
|
|||
/// A store of nodes for a Tiered Sparse Merkle tree.
|
|||
///
|
|||
/// The store contains information about all nodes as well as information about which of the nodes
|
|||
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
|
|||
/// are used to determine the position of the leaves in the tree.
|
|||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
|||
pub struct NodeStore {
|
|||
nodes: BTreeMap<NodeIndex, RpoDigest>,
|
|||
upper_leaves: BTreeSet<NodeIndex>,
|
|||
bottom_leaves: BTreeSet<u64>,
|
|||
}
|
|||
|
|||
impl NodeStore {
|
|||
// CONSTRUCTOR
|
|||
// --------------------------------------------------------------------------------------------
|
|||
/// Returns a new instance of [NodeStore] instantiated with the specified root node.
|
|||
///
|
|||
/// Root node is assumed to be a root of an empty sparse Merkle tree.
|
|||
pub fn new(root_node: RpoDigest) -> Self {
|
|||
let mut nodes = BTreeMap::default();
|
|||
nodes.insert(NodeIndex::root(), root_node);
|
|||
|
|||
Self {
|
|||
nodes,
|
|||
upper_leaves: BTreeSet::default(),
|
|||
bottom_leaves: BTreeSet::default(),
|
|||
}
|
|||
}
|
|||
|
|||
// PUBLIC ACCESSORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns a node at the specified index.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// - The specified index depth is 0 or greater than 64.
|
|||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
|||
/// when a leaf node with the same index prefix exists at a tier higher than the requested
|
|||
/// node.
|
|||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
|||
self.validate_node_access(index)?;
|
|||
Ok(self.get_node_unchecked(&index))
|
|||
}
|
|||
|
|||
/// Returns a Merkle path from the node at the specified index to the root.
|
|||
///
|
|||
/// The node itself is not included in the path.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// - The specified index depth is 0 or greater than 64.
|
|||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
|||
/// when a leaf node with the same index prefix exists at a tier higher than the node to
|
|||
/// which the path is requested.
|
|||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
|||
self.validate_node_access(index)?;
|
|||
|
|||
let mut path = Vec::with_capacity(index.depth() as usize);
|
|||
for _ in 0..index.depth() {
|
|||
let node = self.get_node_unchecked(&index.sibling());
|
|||
path.push(node);
|
|||
index.move_up();
|
|||
}
|
|||
|
|||
Ok(path.into())
|
|||
}
|
|||
|
|||
/// Returns a Merkle path to the node specified by the key together with a flag indicating,
|
|||
/// whether this node is a leaf at depths 16, 32, or 48.
|
|||
pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
|
|||
let (index, leaf_exists) = self.get_leaf_index(key);
|
|||
let index: NodeIndex = index.into();
|
|||
let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
|
|||
(path, index, leaf_exists)
|
|||
}
|
|||
|
|||
/// Returns an index at which a leaf node for the specified key should be inserted.
|
|||
///
|
|||
/// The second value in the returned tuple is set to true if the node at the returned index
|
|||
/// is already a leaf node.
|
|||
pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
|
|||
// traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
|
|||
// a node at any of the tiers is either a leaf or a root of an empty subtree.
|
|||
const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
|
|||
for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
|
|||
let index = LeafNodeIndex::from_key(key, tier_depth);
|
|||
if self.upper_leaves.contains(&index) {
|
|||
return (index, true);
|
|||
} else if !self.nodes.contains_key(&index) {
|
|||
return (index, false);
|
|||
}
|
|||
}
|
|||
|
|||
// if we got here, that means all of the nodes checked so far are internal nodes, and
|
|||
// the new node would need to be inserted in the bottom tier.
|
|||
let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
|
|||
(index, self.bottom_leaves.contains(&index.value()))
|
|||
}
|
|||
|
|||
/// Traverses the tree up from the bottom tier starting at the specified leaf index and
|
|||
/// returns the depth of the first node which hash more than one child. The returned depth
|
|||
/// is rounded up to the next tier.
|
|||
pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
|
|||
let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);
|
|||
|
|||
for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
|
|||
let sibling_index = index.sibling();
|
|||
if self.nodes.contains_key(&sibling_index) {
|
|||
break;
|
|||
}
|
|||
index.move_up();
|
|||
}
|
|||
|
|||
let tier = (index.depth() - 1) / TIER_SIZE;
|
|||
TIER_DEPTHS[tier as usize]
|
|||
}
|
|||
|
|||
// ITERATORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
|
|||
/// at depths 16 32, 48, or 64).
|
|||
///
|
|||
/// The iterator order is unspecified.
|
|||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
|||
self.nodes.iter().filter_map(|(index, node)| {
|
|||
if self.is_internal_node(index) {
|
|||
Some(InnerNodeInfo {
|
|||
value: *node,
|
|||
left: self.get_node_unchecked(&index.left_child()),
|
|||
right: self.get_node_unchecked(&index.right_child()),
|
|||
})
|
|||
} else {
|
|||
None
|
|||
}
|
|||
})
|
|||
}
|
|||
|
|||
/// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
|
|||
/// Tiered Sparse Merkle tree.
|
|||
pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
|
|||
self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
|
|||
}
|
|||
|
|||
/// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
|
|||
/// Sparse Merkle tree.
|
|||
pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
|
|||
self.bottom_leaves.iter().map(|value| {
|
|||
let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
|
|||
(value, &self.nodes[&index])
|
|||
})
|
|||
}
|
|||
|
|||
// STATE MUTATORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Replaces the leaf node at the specified index with a tree consisting of two leaves located
|
|||
/// at the specified indexes. Recomputes and returns the new root.
|
|||
pub fn replace_leaf_with_subtree(
|
|||
&mut self,
|
|||
leaf_index: LeafNodeIndex,
|
|||
subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
|
|||
) -> RpoDigest {
|
|||
debug_assert!(self.is_non_empty_leaf(&leaf_index));
|
|||
debug_assert!(!is_empty_root(&subtree_leaves[0].1));
|
|||
debug_assert!(!is_empty_root(&subtree_leaves[1].1));
|
|||
debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
|
|||
debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());
|
|||
|
|||
self.upper_leaves.remove(&leaf_index);
|
|||
|
|||
if subtree_leaves[0].0 == subtree_leaves[1].0 {
|
|||
// if the subtree is for a single node at depth 64, we only need to insert one node
|
|||
debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
|
|||
debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
|
|||
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
|
|||
} else {
|
|||
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
|
|||
self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
|
|||
}
|
|||
}
|
|||
|
|||
/// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
|
|||
/// containing the retained leaf.
|
|||
///
|
|||
/// This has the effect of deleting the the node at the `removed_leaf` index from the tree,
|
|||
/// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
|
|||
pub fn replace_subtree_with_leaf(
|
|||
&mut self,
|
|||
removed_leaf: LeafNodeIndex,
|
|||
retained_leaf: LeafNodeIndex,
|
|||
new_depth: u8,
|
|||
node: RpoDigest,
|
|||
) -> RpoDigest {
|
|||
debug_assert!(!is_empty_root(&node));
|
|||
debug_assert!(self.is_non_empty_leaf(&removed_leaf));
|
|||
debug_assert!(self.is_non_empty_leaf(&retained_leaf));
|
|||
debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
|
|||
debug_assert!(removed_leaf.depth() > new_depth);
|
|||
|
|||
// remove the branches leading up to the tier to which the retained leaf is to be moved
|
|||
self.remove_branch(removed_leaf, new_depth);
|
|||
self.remove_branch(retained_leaf, new_depth);
|
|||
|
|||
// compute the index of the common root for retained and removed leaves
|
|||
let mut new_index = retained_leaf;
|
|||
new_index.move_up_to(new_depth);
|
|||
|
|||
// insert the node at the root index
|
|||
self.insert_leaf_node(new_index, node)
|
|||
}
|
|||
|
|||
/// Inserts the specified node at the specified index; recomputes and returns the new root
|
|||
/// of the Tiered Sparse Merkle tree.
|
|||
///
|
|||
/// This method assumes that the provided node is a non-empty value, and that there is no node
|
|||
/// at the specified index.
|
|||
pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
|||
debug_assert!(!is_empty_root(&node));
|
|||
debug_assert_eq!(self.nodes.get(&index), None);
|
|||
|
|||
// mark the node as the leaf
|
|||
if index.depth() == MAX_DEPTH {
|
|||
self.bottom_leaves.insert(index.value());
|
|||
} else {
|
|||
self.upper_leaves.insert(index.into());
|
|||
};
|
|||
|
|||
// insert the node and update the path from the node to the root
|
|||
let mut index: NodeIndex = index.into();
|
|||
for _ in 0..index.depth() {
|
|||
self.nodes.insert(index, node);
|
|||
let sibling = self.get_node_unchecked(&index.sibling());
|
|||
node = Rpo256::merge(&index.build_node(node, sibling));
|
|||
index.move_up();
|
|||
}
|
|||
|
|||
// update the root
|
|||
self.nodes.insert(NodeIndex::root(), node);
|
|||
node
|
|||
}
|
|||
|
|||
/// Updates the node at the specified index with the specified node value; recomputes and
|
|||
/// returns the new root of the Tiered Sparse Merkle tree.
|
|||
///
|
|||
/// This method can accept `node` as either an empty or a non-empty value.
|
|||
pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
|||
debug_assert!(self.is_non_empty_leaf(&index));
|
|||
|
|||
// if the value we are updating the node to is a root of an empty tree, clear the leaf
|
|||
// flag for this node
|
|||
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
|||
if index.depth() == MAX_DEPTH {
|
|||
self.bottom_leaves.remove(&index.value());
|
|||
} else {
|
|||
self.upper_leaves.remove(&index);
|
|||
}
|
|||
} else {
|
|||
debug_assert!(!is_empty_root(&node));
|
|||
}
|
|||
|
|||
// update the path from the node to the root
|
|||
let mut index: NodeIndex = index.into();
|
|||
for _ in 0..index.depth() {
|
|||
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
|||
self.nodes.remove(&index);
|
|||
} else {
|
|||
self.nodes.insert(index, node);
|
|||
}
|
|||
|
|||
let sibling = self.get_node_unchecked(&index.sibling());
|
|||
node = Rpo256::merge(&index.build_node(node, sibling));
|
|||
index.move_up();
|
|||
}
|
|||
|
|||
// update the root
|
|||
self.nodes.insert(NodeIndex::root(), node);
|
|||
node
|
|||
}
|
|||
|
|||
/// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
|
|||
/// and returns the new root of the Tiered Sparse Merkle tree.
|
|||
pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
|
|||
debug_assert!(self.is_non_empty_leaf(&index));
|
|||
let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
|
|||
self.update_leaf_node(index, node)
|
|||
}
|
|||
|
|||
/// Truncates a branch starting with specified leaf at the bottom tier to new depth.
|
|||
///
|
|||
/// This involves removing the part of the branch below the new depth, and then inserting a new
|
|||
/// // node at the new depth.
|
|||
pub fn truncate_branch(
|
|||
&mut self,
|
|||
leaf_index: u64,
|
|||
new_depth: u8,
|
|||
node: RpoDigest,
|
|||
) -> RpoDigest {
|
|||
debug_assert!(self.bottom_leaves.contains(&leaf_index));
|
|||
|
|||
let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
|
|||
self.remove_branch(leaf_index, new_depth);
|
|||
|
|||
leaf_index.move_up_to(new_depth);
|
|||
self.insert_leaf_node(leaf_index, node)
|
|||
}
|
|||
|
|||
// HELPER METHODS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns true if the node at the specified index is a leaf node.
|
|||
fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
|
|||
if index.depth() == MAX_DEPTH {
|
|||
self.bottom_leaves.contains(&index.value())
|
|||
} else {
|
|||
self.upper_leaves.contains(index)
|
|||
}
|
|||
}
|
|||
|
|||
/// Returns true if the node at the specified index is an internal node - i.e., there is
|
|||
/// no leaf at that node and the node does not belong to the bottom tier.
|
|||
fn is_internal_node(&self, index: &NodeIndex) -> bool {
|
|||
if index.depth() == MAX_DEPTH {
|
|||
false
|
|||
} else {
|
|||
!self.upper_leaves.contains(index)
|
|||
}
|
|||
}
|
|||
|
|||
/// Checks if the specified index is valid in the context of this Merkle tree.
|
|||
///
|
|||
/// # Errors
|
|||
/// Returns an error if:
|
|||
/// - The specified index depth is 0 or greater than 64.
|
|||
/// - The node for the specified index does not exists in the Merkle tree. This is possible
|
|||
/// when an ancestors of the specified index is a leaf node.
|
|||
fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
|
|||
if index.is_root() {
|
|||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
|||
} else if index.depth() > MAX_DEPTH {
|
|||
return Err(MerkleError::DepthTooBig(index.depth() as u64));
|
|||
} else {
|
|||
// make sure that there are no leaf nodes in the ancestors of the index; since leaf
|
|||
// nodes can live at specific depth, we just need to check these depths.
|
|||
let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
|
|||
let mut tier_index = index;
|
|||
for &depth in TIER_DEPTHS[..tier].iter().rev() {
|
|||
tier_index.move_up_to(depth);
|
|||
if self.upper_leaves.contains(&tier_index) {
|
|||
return Err(MerkleError::NodeNotInSet(index));
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
Ok(())
|
|||
}
|
|||
|
|||
/// Returns a node at the specified index. If the node does not exist at this index, a root
|
|||
/// for an empty subtree at the index's depth is returned.
|
|||
///
|
|||
/// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
|
|||
/// returned node is valid in the context of this tree.
|
|||
fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
|
|||
match self.nodes.get(index) {
|
|||
Some(node) => *node,
|
|||
None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
|
|||
}
|
|||
}
|
|||
|
|||
/// Removes a sequence of nodes starting at the specified index and traversing the tree up to
|
|||
/// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
|
|||
/// flag is cleared.
|
|||
///
|
|||
/// This method does not update any other nodes and does not recompute the tree root.
|
|||
fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
|
|||
if index.depth() == MAX_DEPTH {
|
|||
self.bottom_leaves.remove(&index.value());
|
|||
} else {
|
|||
self.upper_leaves.remove(&index);
|
|||
}
|
|||
|
|||
let mut index: NodeIndex = index.into();
|
|||
assert!(index.depth() > end_depth);
|
|||
for _ in 0..(index.depth() - end_depth + 1) {
|
|||
self.nodes.remove(&index);
|
|||
index.move_up()
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
// HELPER FUNCTIONS
|
|||
// ================================================================================================
|
|||
|
|||
/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
|
|||
fn is_empty_root(node: &RpoDigest) -> bool {
|
|||
EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node)
|
|||
}
|
@ -0,0 +1,162 @@ |
|||
use super::{
|
|||
get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
|
|||
EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
|
|||
};
|
|||
|
|||
// CONSTANTS
|
|||
// ================================================================================================
|
|||
|
|||
/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

/// Value of an empty leaf ([ZERO; 4]); used to mark non-membership in proofs.
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;

/// Depths at which leaves can exist in a tiered SMT.
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
|||
|
|||
// TIERED SPARSE MERKLE TREE PROOF
|
|||
// ================================================================================================
|
|||
|
|||
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// Tiered Sparse Merkle tree.
///
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
/// the entries will contain a single item with value set to [ZERO; 4].
#[derive(PartialEq, Eq, Debug)]
pub struct TieredSmtProof {
    // Merkle path from the node at the base of the proof up to the tree root.
    path: MerklePath,
    // Key-value pairs describing the node at the base of the path; guaranteed non-empty
    // by the constructor.
    entries: Vec<(RpoDigest, Word)>,
}
|
|||
|
|||
impl TieredSmtProof {
|
|||
// CONSTRUCTOR
|
|||
// --------------------------------------------------------------------------------------------
|
|||
/// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries.
|
|||
///
|
|||
/// # Panics
|
|||
/// Panics if:
|
|||
/// - The length of the path is greater than 64.
|
|||
/// - Entries is an empty vector.
|
|||
/// - Entries contains more than 1 item, but the length of the path is not 64.
|
|||
/// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4].
|
|||
/// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
|
|||
pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
|
|||
where
|
|||
I: IntoIterator<Item = (RpoDigest, Word)>,
|
|||
{
|
|||
let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
|
|||
|
|||
if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
|
|||
return Err(TieredSmtProofError::NotATierPath(path.depth()));
|
|||
}
|
|||
|
|||
if entries.is_empty() {
|
|||
return Err(TieredSmtProofError::EntriesEmpty);
|
|||
}
|
|||
|
|||
if entries.len() > 1 {
|
|||
if path.depth() != MAX_DEPTH {
|
|||
return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
|
|||
}
|
|||
|
|||
let prefix = get_key_prefix(&entries[0].0);
|
|||
for entry in entries.iter().skip(1) {
|
|||
if entry.1 == EMPTY_VALUE {
|
|||
return Err(TieredSmtProofError::EmptyValueNotAllowed);
|
|||
}
|
|||
let current = get_key_prefix(&entry.0);
|
|||
if prefix != current {
|
|||
return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
Ok(Self { path, entries })
|
|||
}
|
|||
|
|||
// PROOF VERIFIER
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
|
|||
/// key-value pair.
|
|||
///
|
|||
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
|
|||
/// it does not mean that the provided key-value pair is not in the tree.
|
|||
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
|
|||
if self.is_value_empty() {
|
|||
if value != &EMPTY_VALUE {
|
|||
return false;
|
|||
}
|
|||
// if the proof is for an empty value, we can verify it against any key which has a
|
|||
// common prefix with the key storied in entries, but the prefix must be greater than
|
|||
// the path length
|
|||
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
|||
if common_prefix_tier < self.path.depth() {
|
|||
return false;
|
|||
}
|
|||
} else if !self.entries.contains(&(*key, *value)) {
|
|||
return false;
|
|||
}
|
|||
|
|||
// make sure the Merkle path resolves to the correct root
|
|||
root == &self.compute_root()
|
|||
}
|
|||
|
|||
// PUBLIC ACCESSORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns the value associated with the specific key according to this proof, or None if
|
|||
/// this proof does not contain a value for the specified key.
|
|||
///
|
|||
/// A key-value pair generated by using this method should pass the `verify_membership()` check.
|
|||
pub fn get(&self, key: &RpoDigest) -> Option<Word> {
|
|||
if self.is_value_empty() {
|
|||
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
|||
if common_prefix_tier < self.path.depth() {
|
|||
None
|
|||
} else {
|
|||
Some(EMPTY_VALUE)
|
|||
}
|
|||
} else {
|
|||
self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
|
|||
}
|
|||
}
|
|||
|
|||
/// Computes the root of a Tiered Sparse Merkle tree to which this proof resolve.
|
|||
pub fn compute_root(&self) -> RpoDigest {
|
|||
let node = self.build_node();
|
|||
let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
|
|||
self.path
|
|||
.compute_root(index.value(), node)
|
|||
.expect("failed to compute Merkle path root")
|
|||
}
|
|||
|
|||
/// Consume the proof and returns its parts.
|
|||
pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
|
|||
(self.path, self.entries)
|
|||
}
|
|||
|
|||
// HELPER METHODS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns true if the proof is for an empty value.
|
|||
fn is_value_empty(&self) -> bool {
|
|||
self.entries[0].1 == EMPTY_VALUE
|
|||
}
|
|||
|
|||
/// Converts the entries contained in this proof into a node value for node at the base of the
|
|||
/// path contained in this proof.
|
|||
fn build_node(&self) -> RpoDigest {
|
|||
let depth = self.path.depth();
|
|||
if self.is_value_empty() {
|
|||
EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
|
|||
} else if depth == MAX_DEPTH {
|
|||
hash_bottom_leaf(&self.entries)
|
|||
} else {
|
|||
let (key, value) = self.entries[0];
|
|||
hash_upper_leaf(key, value, depth)
|
|||
}
|
|||
}
|
|||
}
|
@ -0,0 +1,584 @@ |
|||
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
|
|||
use crate::utils::vec;
|
|||
use core::{
|
|||
cmp::{Ord, Ordering},
|
|||
ops::RangeBounds,
|
|||
};
|
|||
use winter_utils::collections::btree_map::Entry;
|
|||
|
|||
// CONSTANTS
|
|||
// ================================================================================================
|
|||
|
|||
/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
|||
|
|||
// VALUE STORE
|
|||
// ================================================================================================
|
|||
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
///
/// The store is organized in a [BTreeMap] where keys are 64 most significant bits of a key, and
/// the values are the corresponding key-value pairs (or a list of key-value pairs if more than
/// a single key-value pair shares the same 64-bit prefix).
///
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ValueStore {
    // map from a 64-bit key prefix to the key-value pair(s) sharing that prefix
    values: BTreeMap<u64, StoreEntry>,
}
|
|||
|
|||
impl ValueStore {
|
|||
// PUBLIC ACCESSORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns a reference to the value stored under the specified key, or None if there is no
|
|||
/// value associated with the specified key.
|
|||
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
|||
let prefix = get_key_prefix(key);
|
|||
self.values.get(&prefix).and_then(|entry| entry.get(key))
|
|||
}
|
|||
|
|||
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
|||
/// specified prefix.
|
|||
pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
|
|||
self.range(prefix..).next()
|
|||
}
|
|||
|
|||
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
|||
/// specified prefix and the key value is not equal to the exclude_key value.
|
|||
pub fn get_first_filtered(
|
|||
&self,
|
|||
prefix: u64,
|
|||
exclude_key: &RpoDigest,
|
|||
) -> Option<&(RpoDigest, Word)> {
|
|||
self.range(prefix..).find(|(key, _)| key != exclude_key)
|
|||
}
|
|||
|
|||
/// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
|
|||
/// None if no keys with the specified prefix are present in this store.
|
|||
pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
|
|||
self.values.get(&prefix).map(|entry| match entry {
|
|||
StoreEntry::Single(kv_pair) => vec![*kv_pair],
|
|||
StoreEntry::List(kv_pairs) => kv_pairs.clone(),
|
|||
})
|
|||
}
|
|||
|
|||
/// Returns information about a sibling of a leaf node with the specified index, but only if
/// this is the only sibling the leaf has in some subtree starting at the first tier.
///
/// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with
/// the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
///
/// The returned tuple contains: the key-value pair of the sibling as well as the index of
/// the node for the root of the common subtree in which both nodes are leaves.
///
/// This method assumes that the key-value pair for the specified index has already been
/// removed from the store.
pub fn get_lone_sibling(
    &self,
    index: LeafNodeIndex,
) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
    // iterate over tiers from top to bottom, looking at the tiers which are strictly above
    // the depth of the index. This implies that only tiers at depth 32 and 48 will be
    // considered. For each tier, check if the parent of the index at the higher tier
    // contains a single node. The first tier (depth 16) is excluded because we cannot move
    // nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
    // never have "lone siblings".
    for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
        // compute the index of the root at a higher tier
        let mut parent_index = index;
        parent_index.move_up_to(tier_depth);

        // find the lone sibling, if any; we need to handle the "last node" at a given tier
        // separately to specify the bounds for the search correctly
        let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
        let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
            // when the subtree covers the largest prefixes, (parent + 1) << shift would
            // overflow u64, so search to the end of the key space instead of using an
            // exclusive end bound
            let mut iter = self.range(start_prefix..);
            iter.next().filter(|_| iter.next().is_none())
        } else {
            let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
            let mut iter = self.range(start_prefix..end_prefix);
            iter.next().filter(|_| iter.next().is_none())
        };

        // if the subtree contained exactly one key-value pair, it is a lone sibling
        if let Some((key, value)) = sibling {
            return Some((key, value, parent_index));
        }
    }

    None
}
|
|||
|
|||
/// Returns an iterator over all key-value pairs in this store.
|
|||
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
|||
self.values.iter().flat_map(|(_, entry)| entry.iter())
|
|||
}
|
|||
|
|||
// STATE MUTATORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Inserts the specified key-value pair into this store and returns the value previously
|
|||
/// associated with the specified key.
|
|||
///
|
|||
/// If no value was previously associated with the specified key, None is returned.
|
|||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
|||
let prefix = get_key_prefix(&key);
|
|||
match self.values.entry(prefix) {
|
|||
Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
|
|||
Entry::Vacant(entry) => {
|
|||
entry.insert(StoreEntry::new(key, value));
|
|||
None
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
/// Removes the key-value pair for the specified key from this store and returns the value
|
|||
/// associated with this key.
|
|||
///
|
|||
/// If no value was associated with the specified key, None is returned.
|
|||
pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
|
|||
let prefix = get_key_prefix(key);
|
|||
match self.values.entry(prefix) {
|
|||
Entry::Occupied(mut entry) => {
|
|||
let (value, remove_entry) = entry.get_mut().remove(key);
|
|||
if remove_entry {
|
|||
entry.remove_entry();
|
|||
}
|
|||
value
|
|||
}
|
|||
Entry::Vacant(_) => None,
|
|||
}
|
|||
}
|
|||
|
|||
// HELPER METHODS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns an iterator over all key-value pairs contained in this store such that the most
|
|||
/// significant 64 bits of the key lay within the specified bounds.
|
|||
///
|
|||
/// The order of iteration is from the smallest to the largest key.
|
|||
fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
|||
self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
|
|||
}
|
|||
}
|
|||
|
|||
// VALUE NODE
|
|||
// ================================================================================================
|
|||
|
|||
/// An entry in the [ValueStore].
///
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
/// key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum StoreEntry {
    // exactly one key-value pair with a given 64-bit prefix
    Single((RpoDigest, Word)),
    // two or more key-value pairs, kept sorted by key (see `cmp_digests`)
    List(Vec<(RpoDigest, Word)>),
}
|
|||
|
|||
impl StoreEntry {
|
|||
// CONSTRUCTOR
|
|||
// --------------------------------------------------------------------------------------------
|
|||
/// Returns a new [StoreEntry] instantiated with a single key-value pair.
|
|||
pub fn new(key: RpoDigest, value: Word) -> Self {
|
|||
Self::Single((key, value))
|
|||
}
|
|||
|
|||
// PUBLIC ACCESSORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Returns the value associated with the specified key, or None if this entry does not contain
|
|||
/// a value associated with the specified key.
|
|||
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
|||
match self {
|
|||
StoreEntry::Single(kv_pair) => {
|
|||
if kv_pair.0 == *key {
|
|||
Some(&kv_pair.1)
|
|||
} else {
|
|||
None
|
|||
}
|
|||
}
|
|||
StoreEntry::List(kv_pairs) => {
|
|||
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
|||
Ok(pos) => Some(&kv_pairs[pos].1),
|
|||
Err(_) => None,
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
/// Returns an iterator over all key-value pairs in this entry.
|
|||
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
|||
EntryIterator { entry: self, pos: 0 }
|
|||
}
|
|||
|
|||
// STATE MUTATORS
|
|||
// --------------------------------------------------------------------------------------------
|
|||
|
|||
/// Inserts the specified key-value pair into this entry and returns the value previously
|
|||
/// associated with the specified key, or None if no value was associated with the specified
|
|||
/// key.
|
|||
///
|
|||
/// If a new key is inserted, this will also transform a `SingleEntry` into a `ListEntry`.
|
|||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
|||
match self {
|
|||
StoreEntry::Single(kv_pair) => {
|
|||
// if the key is already in this entry, update the value and return
|
|||
if kv_pair.0 == key {
|
|||
let old_value = kv_pair.1;
|
|||
kv_pair.1 = value;
|
|||
return Some(old_value);
|
|||
}
|
|||
|
|||
// transform the entry into a list entry, and make sure the key-value pairs
|
|||
// are sorted by key
|
|||
let mut pairs = vec![*kv_pair, (key, value)];
|
|||
pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));
|
|||
|
|||
*self = StoreEntry::List(pairs);
|
|||
None
|
|||
}
|
|||
StoreEntry::List(pairs) => {
|
|||
match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
|
|||
Ok(pos) => {
|
|||
let old_value = pairs[pos].1;
|
|||
pairs[pos].1 = value;
|
|||
Some(old_value)
|
|||
}
|
|||
Err(pos) => {
|
|||
pairs.insert(pos, (key, value));
|
|||
None
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
/// Removes the key-value pair with the specified key from this entry, and returns the value
|
|||
/// of the removed pair. If the entry did not contain a key-value pair for the specified key,
|
|||
/// None is returned.
|
|||
///
|
|||
/// If the last last key-value pair was removed from the entry, the second tuple value will
|
|||
/// be set to true.
|
|||
pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
|
|||
match self {
|
|||
StoreEntry::Single(kv_pair) => {
|
|||
if kv_pair.0 == *key {
|
|||
(Some(kv_pair.1), true)
|
|||
} else {
|
|||
(None, false)
|
|||
}
|
|||
}
|
|||
StoreEntry::List(kv_pairs) => {
|
|||
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
|||
Ok(pos) => {
|
|||
let kv_pair = kv_pairs.remove(pos);
|
|||
if kv_pairs.len() == 1 {
|
|||
*self = StoreEntry::Single(kv_pairs[0]);
|
|||
}
|
|||
(Some(kv_pair.1), false)
|
|||
}
|
|||
Err(_) => (None, false),
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
/// A custom iterator over key-value pairs of a [StoreEntry].
///
/// For a `SingleEntry` this returns only one value, but for `ListEntry`, this iterates over the
/// entire list of key-value pairs.
pub struct EntryIterator<'a> {
    // the entry being iterated over
    entry: &'a StoreEntry,
    // position of the next pair to yield
    pos: usize,
}
|
|||
|
|||
impl<'a> Iterator for EntryIterator<'a> {
|
|||
type Item = &'a (RpoDigest, Word);
|
|||
|
|||
fn next(&mut self) -> Option<Self::Item> {
|
|||
match self.entry {
|
|||
StoreEntry::Single(kv_pair) => {
|
|||
if self.pos == 0 {
|
|||
self.pos = 1;
|
|||
Some(kv_pair)
|
|||
} else {
|
|||
None
|
|||
}
|
|||
}
|
|||
StoreEntry::List(kv_pairs) => {
|
|||
if self.pos >= kv_pairs.len() {
|
|||
None
|
|||
} else {
|
|||
let kv_pair = &kv_pairs[self.pos];
|
|||
self.pos += 1;
|
|||
Some(kv_pair)
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
}
|
|||
|
|||
// HELPER FUNCTIONS
|
|||
// ================================================================================================
|
|||
|
|||
/// Compares two digests element-by-element using their integer representations starting with the
|
|||
/// most significant element.
|
|||
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
|
|||
let d1 = Word::from(d1);
|
|||
let d2 = Word::from(d2);
|
|||
|
|||
for (v1, v2) in d1.iter().zip(d2.iter()).rev() {
|
|||
let v1 = v1.as_int();
|
|||
let v2 = v2.as_int();
|
|||
if v1 != v2 {
|
|||
return v1.cmp(&v2);
|
|||
}
|
|||
}
|
|||
|
|||
Ordering::Equal
|
|||
}
|
|||
|
|||
// TESTS
|
|||
// ================================================================================================
|
|||
|
|||
#[cfg(test)]
mod tests {
    use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
    use crate::{Felt, ONE, WORD_SIZE, ZERO};

    // NOTE: in these tests the 64-bit `raw_*` constants double as key prefixes: the prefix of
    // a key is taken from the last digest element (see `get_key_prefix` usage in the module).

    #[test]
    fn test_insert() {
        let mut store = ValueStore::default();

        // insert the first key-value pair into the store
        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];

        assert!(store.insert(key_a, value_a).is_none());
        assert_eq!(store.values.len(), 1);

        let entry = store.values.get(&raw_a).unwrap();
        let expected_entry = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry, &expected_entry);

        // insert a key-value pair with a different key into the store; since the keys are
        // different, another entry is added to the values map
        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];

        assert!(store.insert(key_b, value_b).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // insert a key-value pair with the same 64-bit key prefix as the first key; this should
        // transform the first entry into a List entry
        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];

        assert!(store.insert(key_c, value_c).is_none());
        assert_eq!(store.values.len(), 2);

        // key_c sorts before key_a (element-wise comparison, most significant element first)
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // replace values for keys a and b
        let value_a2 = [ONE, ONE, ONE, ZERO];
        let value_b2 = [ZERO, ZERO, ZERO, ONE];

        assert_eq!(store.insert(key_a, value_a2), Some(value_a));
        assert_eq!(store.values.len(), 2);

        assert_eq!(store.insert(key_b, value_b2), Some(value_b));
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);

        // insert one more key-value pair with the same 64-bit key-prefix as the first key
        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];

        assert!(store.insert(key_d, value_d).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);
    }

    #[test]
    fn test_remove() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // remove non-existent keys
        let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
        assert!(store.remove(&key_e).is_none());

        let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
        assert!(store.remove(&key_f).is_none());

        // remove keys from the list entry
        assert_eq!(store.remove(&key_c).unwrap(), value_c);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        // removing the second-to-last pair collapses the List back into a Single
        assert_eq!(store.remove(&key_a).unwrap(), value_a);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_d, value_d));
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_d).unwrap(), value_d);
        assert!(store.values.get(&raw_a).is_none());
        assert_eq!(store.values.len(), 1);

        // remove a key from a single entry
        assert_eq!(store.remove(&key_b).unwrap(), value_b);
        assert!(store.values.get(&raw_b).is_none());
        assert_eq!(store.values.len(), 0);
    }

    #[test]
    fn test_range() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
        let value_e = [ZERO, ZERO, ZERO, ONE];
        store.insert(key_e, value_e);

        // check the entire range: iteration is ordered by prefix first, then by key
        let mut iter = store.range(..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_e, value_e)));
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e
        let mut iter = store.range(raw_a..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e and b
        let mut iter = store.range(raw_a..raw_b);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_get_lone_sibling() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        // check sibling node for `a`
        let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
        let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
        assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));

        // check sibling node for `b`; raw_b has all-ones top 16 bits, which exercises the
        // "last node at a tier" branch of get_lone_sibling
        let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
        let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
        assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));

        // check some other sibling for some other index
        // NOTE(review): this value has only 16 significant bits for a depth-32 index —
        // presumably intended as a depth-32 index with a zeroed low half; confirm.
        let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
        assert_eq!(store.get_lone_sibling(index), None);
    }
}
|
@ -0,0 +1,31 @@ |
|||
/// A trait for computing the difference between two objects.
///
/// NOTE(review): the `K`/`V` type parameters are not referenced by any item of this trait —
/// presumably they constrain implementations over keyed collections; confirm against callers.
pub trait Diff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Returns a [Self::DiffType] object that represents the difference between this object and
    /// other.
    fn diff(&self, other: &Self) -> Self::DiffType;
}
|
|||
|
|||
/// A trait for applying the difference between two objects.
///
/// NOTE(review): the `K`/`V` type parameters are not referenced by any item of this trait —
/// presumably they mirror [Diff]; confirm against callers.
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
    fn apply(&mut self, diff: Self::DiffType);
}
|
|||
|
|||
/// A trait for applying the difference between two objects with the possibility of failure.
///
/// NOTE(review): the `K`/`V` type parameters are not referenced by any item of this trait —
/// presumably they mirror [Diff]; confirm against callers.
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// An error type that can be returned if the changes cannot be applied.
    type Error;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
    /// Returns an error if the changes cannot be applied.
    fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
}
|