From 701a187e7f7db3a62d3ef82af34504e0b50d20ca Mon Sep 17 00:00:00 2001
From: Grzegorz Swirski
Date: Fri, 15 Sep 2023 11:09:03 +0200
Subject: [PATCH] feat: implement RPO hash using SVE instructions

---
 .gitignore                    |   3 +
 Cargo.toml                    |   4 +
 arch/arm64-sve/CMakeLists.txt |  10 ++
 arch/arm64-sve/library.c      |  78 ++++++++++++
 arch/arm64-sve/library.h      |  12 ++
 arch/arm64-sve/rpo_hash.h     | 221 ++++++++++++++++++++++++++++++++++
 arch/arm64-sve/test.c         |  27 +++++
 build.rs                      |  17 +++
 src/hash/rpo/mod.rs           |  68 ++++++++++-
 9 files changed, 436 insertions(+), 4 deletions(-)
 create mode 100644 arch/arm64-sve/CMakeLists.txt
 create mode 100644 arch/arm64-sve/library.c
 create mode 100644 arch/arm64-sve/library.h
 create mode 100644 arch/arm64-sve/rpo_hash.h
 create mode 100644 arch/arm64-sve/test.c
 create mode 100644 build.rs

diff --git a/.gitignore b/.gitignore
index 088ba6b..1f17879 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,6 @@ Cargo.lock
 
 # These are backup files generated by rustfmt
 **/*.rs.bk
+
+# Generated by cmake
+cmake-build-*
diff --git a/Cargo.toml b/Cargo.toml
index 04dc990..4c92e57 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ name = "store"
 harness = false
 
 [features]
+arch-arm64-sve = ["dep:cc"]
 default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"]
 std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std", "rand_utils"]
 serde = ["winter_math/serde", "dep:serde", "serde/alloc"]
@@ -49,3 +50,6 @@ rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true }
 criterion = { version = "0.5", features = ["html_reports"] }
 proptest = "1.1.0"
 rand_utils = { version = "0.6", package = "winter-rand-utils" }
+
+[build-dependencies]
+cc = { version = "1.0.79", optional = true }
diff --git a/arch/arm64-sve/CMakeLists.txt b/arch/arm64-sve/CMakeLists.txt
new file mode 100644
index 0000000..40710b1
--- /dev/null
+++ b/arch/arm64-sve/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.0)
+project(rpo_sve C)
+
+set(CMAKE_C_STANDARD 23)
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve -Wall -Wextra -pedantic -g -O3")
+
+add_library(rpo_sve library.c rpo_hash.h)
+
+add_executable(rpo_test test.c)
+target_link_libraries(rpo_test rpo_sve)
diff --git a/arch/arm64-sve/library.c b/arch/arm64-sve/library.c
new file mode 100644
index 0000000..a1791f7
--- /dev/null
+++ b/arch/arm64-sve/library.c
@@ -0,0 +1,78 @@
+#include <stdint.h>
+#include <arm_sve.h>
+#include "library.h"
+#include "rpo_hash.h"
+
+// The STATE_WIDTH of the RPO hash is 12x u64 elements.
+// The current generation of SVE-enabled processors - Neoverse V1
+// (e.g. AWS Graviton3) - has 256-bit vector registers (4x u64).
+// This allows us to split the state into 3 vectors of 4 elements
+// and process all 3 independently of each other.
+//
+// We see the biggest performance gains by leveraging both
+// vector and scalar operations on parts of the state array.
+// Due to the high latency of vector operations, the processor is able
+// to reorder and pipeline scalar instructions while we wait for
+// vector results. This effectively gives us some 'free' scalar
+// operations and masks vector latency.
+//
+// This also means that we can fully saturate all four arithmetic
+// units of the processor (2x scalar, 2x SIMD).
+//
+// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS
+// GAIN WIDER REGISTERS.
+// It's quite possible that with 8x u64
+// vectors, processing 2 partially filled vectors might
+// be easier and faster than dealing with scalar operations
+// on the remainder of the array.
+//
+// FOR NOW, THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back
+// to the regular, already highly optimized scalar version
+// if the conditions are not met.
+
+bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
+    const uint64_t vl = svcntd();   // number of u64 numbers in one SVE vector
+
+    if (vl != 4) {
+        return false;
+    }
+
+    svbool_t ptrue = svptrue_b64();
+
+    svuint64_t state1 = svld1(ptrue, state + 0*vl);
+    svuint64_t state2 = svld1(ptrue, state + 1*vl);
+
+    svuint64_t const1 = svld1(ptrue, constants + 0*vl);
+    svuint64_t const2 = svld1(ptrue, constants + 1*vl);
+
+    add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8);
+    apply_sbox(ptrue, &state1, &state2, state+8);
+
+    svst1(ptrue, state + 0*vl, state1);
+    svst1(ptrue, state + 1*vl, state2);
+
+    return true;
+}
+
+bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
+    const uint64_t vl = svcntd();   // number of u64 numbers in one SVE vector
+
+    if (vl != 4) {
+        return false;
+    }
+
+    svbool_t ptrue = svptrue_b64();
+
+    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
+    svuint64_t state2 = svld1(ptrue, state + 1 * vl);
+
+    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
+    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
+
+    add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
+    apply_inv_sbox(ptrue, &state1, &state2, state + 8);
+
+    svst1(ptrue, state + 0 * vl, state1);
+    svst1(ptrue, state + 1 * vl, state2);
+
+    return true;
+}
diff --git a/arch/arm64-sve/library.h b/arch/arm64-sve/library.h
new file mode 100644
index 0000000..c8f1cdd
--- /dev/null
+++ b/arch/arm64-sve/library.h
@@ -0,0 +1,12 @@
+#ifndef CRYPTO_LIBRARY_H
+#define CRYPTO_LIBRARY_H
+
+#include <stdint.h>
+#include <stdbool.h>
+
+#define STATE_WIDTH 12
+
+bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
+bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
+
+#endif //CRYPTO_LIBRARY_H
diff --git a/arch/arm64-sve/rpo_hash.h b/arch/arm64-sve/rpo_hash.h
new file mode 100644
index 0000000..567298f
--- /dev/null
+++ b/arch/arm64-sve/rpo_hash.h
@@ -0,0 +1,221 @@
+#ifndef RPO_SVE_RPO_HASH_H
+#define RPO_SVE_RPO_HASH_H
+
+#include <stdint.h>
+#include <stddef.h>
+#include <string.h>
+#include <arm_sve.h>
+
+#define COPY(NAME, VIN1, VIN2, SIN3)              \
+    svuint64_t NAME ## _1 = VIN1;                 \
+    svuint64_t NAME ## _2 = VIN2;                 \
+    uint64_t NAME ## _3[4];                       \
+    memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t))
+
+#define MULTIPLY(PRED, DEST, OP)                  \
+    mul(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3)
+
+#define SQUARE(PRED, NAME)                        \
+    sq(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3)
+
+#define SQUARE_DEST(PRED, DEST, SRC)              \
+    COPY(DEST, SRC ## _1, SRC ## _2, SRC ## _3);  \
+    SQUARE(PRED, DEST);
+
+#define POW_ACC(PRED, NAME, CNT, TAIL)            \
+    for (size_t i = 0; i < CNT; i++) {            \
+        SQUARE(PRED, NAME);                       \
+    }                                             \
+    MULTIPLY(PRED, NAME, TAIL);
+
+#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \
+    COPY(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \
+    POW_ACC(PRED, DEST, CNT, TAIL)
+
+extern inline void add_constants(
+    svbool_t pg,
+    svuint64_t *state1,
+    svuint64_t *const1,
+    svuint64_t *state2,
+    svuint64_t *const2,
+    uint64_t *state3,
+    uint64_t *const3
+) {
+    uint64_t Ms = 0xFFFFFFFF00000001ull;
+    svuint64_t Mv = svindex_u64(Ms, 0);
+
+    uint64_t p_1 = Ms - const3[0];
+    uint64_t p_2 = Ms - const3[1];
+    uint64_t p_3 = Ms - const3[2];
+    uint64_t p_4 = Ms - const3[3];
+
+    uint64_t x_1, x_2, x_3, x_4;
+    uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1);
+    uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2);
+    uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3);
+    uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4);
+
+    state3[0] = x_1 - (uint64_t)adj_1;
+    state3[1] = x_2 - (uint64_t)adj_2;
+    state3[2] = x_3 - (uint64_t)adj_3;
+    state3[3] = x_4 - (uint64_t)adj_4;
+
+    svuint64_t p1 = svsub_x(pg, Mv, *const1);
+    svuint64_t p2 = svsub_x(pg, Mv, *const2);
+
+    svuint64_t x1 = svsub_x(pg, *state1, p1);
+    svuint64_t x2 = svsub_x(pg, *state2, p2);
+
+    svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
+    svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
+
+    *state1 = svsub_m(pt1, x1, (uint32_t)-1);
+    *state2 = svsub_m(pt2, x2, (uint32_t)-1);
+}
+
+extern inline void mul(
+    svbool_t pg,
+    svuint64_t *r1,
+    const svuint64_t *op1,
+    svuint64_t *r2,
+    const svuint64_t *op2,
+    uint64_t *r3,
+    const uint64_t *op3
+) {
+    __uint128_t x_1 = r3[0];
+    __uint128_t x_2 = r3[1];
+    __uint128_t x_3 = r3[2];
+    __uint128_t x_4 = r3[3];
+
+    x_1 *= (__uint128_t) op3[0];
+    x_2 *= (__uint128_t) op3[1];
+    x_3 *= (__uint128_t) op3[2];
+    x_4 *= (__uint128_t) op3[3];
+
+    uint64_t x0_1 = x_1;
+    uint64_t x0_2 = x_2;
+    uint64_t x0_3 = x_3;
+    uint64_t x0_4 = x_4;
+
+    svuint64_t l1 = svmul_x(pg, *r1, *op1);
+    svuint64_t l2 = svmul_x(pg, *r2, *op2);
+
+    uint64_t x1_1 = (x_1 >> 64);
+    uint64_t x1_2 = (x_2 >> 64);
+    uint64_t x1_3 = (x_3 >> 64);
+    uint64_t x1_4 = (x_4 >> 64);
+
+    uint64_t a_1, a_2, a_3, a_4;
+    uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
+    uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
+    uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
+    uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);
+
+    svuint64_t ls1 = svlsl_x(pg, l1, 32);
+    svuint64_t ls2 = svlsl_x(pg, l2, 32);
+
+    svuint64_t a1 = svadd_x(pg, l1, ls1);
+    svuint64_t a2 = svadd_x(pg, l2, ls2);
+
+    svbool_t e1 = svcmplt(pg, a1, l1);
+    svbool_t e2 = svcmplt(pg, a2, l2);
+
+    svuint64_t as1 = svlsr_x(pg, a1, 32);
+    svuint64_t as2 = svlsr_x(pg, a2, 32);
+
+    svuint64_t b1 = svsub_x(pg, a1, as1);
+    svuint64_t b2 = svsub_x(pg, a2, as2);
+
+    b1 = svsub_m(e1, b1, 1);
+    b2 = svsub_m(e2, b2, 1);
+
+    uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
+    uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
+    uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
+    uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;
+
+    uint64_t r_1, r_2, r_3, r_4;
+    uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
+    uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
+    uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
+    uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);
+
+    svuint64_t h1 = svmulh_x(pg, *r1, *op1);
+    svuint64_t h2 = svmulh_x(pg, *r2, *op2);
+
+    svuint64_t tr1 = svsub_x(pg, h1, b1);
+    svuint64_t tr2 = svsub_x(pg, h2, b2);
+
+    svbool_t c1 = svcmplt_u64(pg, h1, b1);
+    svbool_t c2 = svcmplt_u64(pg, h2, b2);
+
+    *r1 = svsub_m(c1, tr1, (uint32_t) -1);
+    *r2 = svsub_m(c2, tr2, (uint32_t) -1);
+
+    uint32_t minus1_1 = 0 - c_1;
+    uint32_t minus1_2 = 0 - c_2;
+    uint32_t minus1_3 = 0 - c_3;
+    uint32_t minus1_4 = 0 - c_4;
+
+    r3[0] = r_1 - (uint64_t)minus1_1;
+    r3[1] = r_2 - (uint64_t)minus1_2;
+    r3[2] = r_3 - (uint64_t)minus1_3;
+    r3[3] = r_4 - (uint64_t)minus1_4;
+}
+
+extern inline void sq(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) {
+    mul(pg, a, a, b, b, c, c);
+}
+
+extern inline void apply_sbox(
+    svbool_t pg,
+    svuint64_t *state1,
+    svuint64_t *state2,
+    uint64_t *state3
+) {
+    COPY(x, *state1, *state2, state3);                  // copy input to x
+    SQUARE(pg, x);                                      // x contains input^2
+    mul(pg, state1, &x_1, state2, &x_2, state3, x_3);   // state contains input^3
+    SQUARE(pg, x);                                      // x contains input^4
+    mul(pg, state1, &x_1, state2, &x_2, state3, x_3);   // state contains input^7
+}
+
+extern inline void apply_inv_sbox(
+    svbool_t pg,
+    svuint64_t *state_1,
+    svuint64_t *state_2,
+    uint64_t *state_3
+) {
+    // base^10
+    COPY(t1, *state_1, *state_2, state_3);
+    SQUARE(pg, t1);
+
+    // base^100
+    SQUARE_DEST(pg, t2, t1);
+
+    // base^100100
+    POW_ACC_DEST(pg, t3, 3, t2, t2);
+
+    // base^100100100100
+    POW_ACC_DEST(pg, t4, 6, t3, t3);
+
+    // compute base^100100100100100100100100
+    POW_ACC_DEST(pg, t5, 12, t4, t4);
+
+    // compute base^100100100100100100100100100100
+    POW_ACC_DEST(pg, t6, 6, t5, t3);
+
+    // compute base^1001001001001001001001001001000100100100100100100100100100100
+    POW_ACC_DEST(pg, t7, 31, t6, t6);
+
+    // compute base^1001001001001001001001001001000110110110110110110110110110110111
+    SQUARE(pg, t7);
+    MULTIPLY(pg, t7, t6);
+    SQUARE(pg, t7);
+    SQUARE(pg, t7);
+    MULTIPLY(pg, t7, t1);
+    MULTIPLY(pg, t7, t2);
+    mul(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3);
+}
+
+#endif //RPO_SVE_RPO_HASH_H
diff --git a/arch/arm64-sve/test.c b/arch/arm64-sve/test.c
new file mode 100644
index 0000000..78e2f50
--- /dev/null
+++ b/arch/arm64-sve/test.c
@@ -0,0 +1,27 @@
+#include <stdio.h>
+#include "library.h"
+
+void print_array(size_t len, uint64_t arr[len]);
+
+int main() {
+    uint64_t C[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+    uint64_t T[STATE_WIDTH] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
+
+    add_constants_and_apply_sbox(T, C);
+    add_constants_and_apply_inv_sbox(T, C);
+
+    print_array(STATE_WIDTH, T);
+
+    return 0;
+}
+
+void print_array(size_t len, uint64_t arr[len])
+{
+    printf("[");
+    for (size_t i = 0; i < len; i++)
+    {
+        printf("%lu ", arr[i]);
+    }
+
+    printf("]\n");
+}
diff --git a/build.rs b/build.rs
new file mode 100644
index 0000000..7d95857
--- /dev/null
+++ b/build.rs
@@ -0,0 +1,17 @@
+fn main() {
+    #[cfg(feature = "arch-arm64-sve")]
+    compile_arch_arm64_sve();
+}
+
+#[cfg(feature = "arch-arm64-sve")]
+fn compile_arch_arm64_sve() {
+    println!("cargo:rerun-if-changed=arch/arm64-sve/library.c");
+    println!("cargo:rerun-if-changed=arch/arm64-sve/library.h");
+    println!("cargo:rerun-if-changed=arch/arm64-sve/rpo_hash.h");
+
+    cc::Build::new()
+        .file("arch/arm64-sve/library.c")
+        .flag("-march=armv8-a+sve")
+        .flag("-O3")
+        .compile("rpo_sve");
+}
diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs
index 95f2c97..dc7df3f 100644
--- a/src/hash/rpo/mod.rs
+++ b/src/hash/rpo/mod.rs
@@ -10,6 +10,19 @@ use mds_freq::mds_multiply_freq;
 
 #[cfg(test)]
 mod tests;
 
+#[cfg(feature = "arch-arm64-sve")]
+#[link(name = "rpo_sve", kind = "static")]
+extern "C" {
+    fn add_constants_and_apply_sbox(
+        state: *mut std::ffi::c_ulong,
+        constants: *const std::ffi::c_ulong,
+    ) -> bool;
+    fn add_constants_and_apply_inv_sbox(
+        state: *mut std::ffi::c_ulong,
+        constants: *const std::ffi::c_ulong,
+    ) -> bool;
+}
+
 // CONSTANTS
 // ================================================================================================
@@ -345,18 +358,65 @@
     pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
         // apply first half of RPO round
         Self::apply_mds(state);
-        Self::add_constants(state, &ARK1[round]);
-        Self::apply_sbox(state);
+        if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) {
+            Self::add_constants(state, &ARK1[round]);
+            Self::apply_sbox(state);
+        }
 
         // apply second half of RPO round
         Self::apply_mds(state);
-        Self::add_constants(state, &ARK2[round]);
-        Self::apply_inv_sbox(state);
+        if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
+            Self::add_constants(state, &ARK2[round]);
+            Self::apply_inv_sbox(state);
+        }
     }
 
     // HELPER FUNCTIONS
     // --------------------------------------------------------------------------------------------
 
+    #[inline(always)]
+    #[cfg(feature = "arch-arm64-sve")]
+    fn optimized_add_constants_and_apply_sbox(
+        state: &mut [Felt; STATE_WIDTH],
+        ark: &[Felt; STATE_WIDTH],
+    ) -> bool {
+        unsafe {
+            add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64)
+        }
+    }
+
+    #[inline(always)]
+    #[cfg(not(feature = "arch-arm64-sve"))]
+    fn optimized_add_constants_and_apply_sbox(
+        _state: &mut [Felt; STATE_WIDTH],
+        _ark: &[Felt; STATE_WIDTH],
+    ) -> bool {
+        false
+    }
+
+    #[inline(always)]
+    #[cfg(feature = "arch-arm64-sve")]
+    fn optimized_add_constants_and_apply_inv_sbox(
+        state: &mut [Felt; STATE_WIDTH],
+        ark: &[Felt; STATE_WIDTH],
+    ) -> bool {
+        unsafe {
+            add_constants_and_apply_inv_sbox(
+                state.as_mut_ptr() as *mut u64,
+                ark.as_ptr() as *const u64,
+            )
+        }
+    }
+
+    #[inline(always)]
+    #[cfg(not(feature = "arch-arm64-sve"))]
+    fn optimized_add_constants_and_apply_inv_sbox(
+        _state: &mut [Felt; STATE_WIDTH],
+        _ark: &[Felt; STATE_WIDTH],
+    ) -> bool {
+        false
+    }
+
     #[inline(always)]
     fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
        let mut result = [ZERO; STATE_WIDTH];
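
Reviewer note (not part of the patch): the scalar lanes of mul() above appear to follow the same constant-time Montgomery-style reduction that winterfell's f64 field uses for the Goldilocks prime p = 2^64 - 2^32 + 1, i.e. a 128-bit product x is reduced to x * 2^-64 mod p. This is consistent with the Felt words handed across the FFI boundary already being in Montgomery form, so the C code can multiply the raw u64 words directly. The sketch below is a minimal, standalone restatement of that reduction so it can be checked against a naive 128-bit modulo; the helper names (mont_red, to_mont, from_mont) and the sample inputs are illustrative only and are not part of the library.

/* mont_check.c - illustrative sketch, assuming the reduction in mul() is a
 * Montgomery reduction with R = 2^64 over the Goldilocks prime. */
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>

#define GOLDILOCKS_P 0xFFFFFFFF00000001ull

/* Same x0/x1/a/e/b/r/c sequence as the scalar lanes of mul(). */
static uint64_t mont_red(__uint128_t x) {
    uint64_t x0 = (uint64_t)x;          /* low 64 bits of the product  */
    uint64_t x1 = (uint64_t)(x >> 64);  /* high 64 bits of the product */

    uint64_t a;
    uint64_t e = __builtin_add_overflow(x0, x0 << 32, &a);
    uint64_t b = a - (a >> 32) - e;     /* unsigned arithmetic wraps, as intended */

    uint64_t r;
    uint64_t c = __builtin_sub_overflow(x1, b, &r);
    return r - (uint64_t)(uint32_t)(0u - (uint32_t)c);  /* subtract 2^32 - 1 on borrow */
}

/* Enter/leave the Montgomery domain (the library keeps values in-domain). */
static uint64_t to_mont(uint64_t v)   { return (uint64_t)((((__uint128_t)v) << 64) % GOLDILOCKS_P); }
static uint64_t from_mont(uint64_t v) { return mont_red((__uint128_t)v) % GOLDILOCKS_P; }

int main(void) {
    /* arbitrary sample values below p */
    uint64_t a = 0x0123456789abcdefull;
    uint64_t b = 0xfedcba9876543210ull;

    /* Montgomery-domain multiply, mirroring what mul() does on raw words. */
    uint64_t prod_mont = mont_red((__uint128_t)to_mont(a) * to_mont(b));
    uint64_t fast  = from_mont(prod_mont);
    uint64_t naive = (uint64_t)(((__uint128_t)a * b) % GOLDILOCKS_P);

    printf("fast=%" PRIu64 " naive=%" PRIu64 " %s\n",
           fast, naive, fast == naive ? "OK" : "MISMATCH");
    return fast == naive ? 0 : 1;
}

Building this sketch with any GCC/Clang on a 64-bit target and seeing "OK" is a quick way to convince oneself that the branchless scalar sequence in mul() is a faithful modular multiply once the Montgomery factor is accounted for; the vector lanes compute the same thing with svmul_x/svmulh_x providing the low and high product halves.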