You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

313 lines
8.8 KiB

const bigInt = require("snarkjs").bigInt;
const SMTMemDB = require("./smt_memdb");
const {hash0, hash1} = require("./smt_hashes_poseidon");
class SMT {
constructor(db, root) {
this.db = db;
this.root = root;
}
_splitBits(_key) {
let k = bigInt(_key);
const res = [];
while (!k.isZero()) {
if (k.isOdd()) {
res.push(true);
} else {
res.push(false);
}
k = k.shr(1);
}
while (res.length<256) res.push(false);
return res;
}
async update(_key, _newValue) {
const key = bigInt(_key);
const newValue = bigInt(_newValue);
const resFind = await this.find(key);
const res = {};
res.oldRoot = this.root;
res.oldKey = key;
res.oldValue = resFind.foundValue;
res.newKey = key;
res.newValue = newValue;
res.siblings = resFind.siblings;
const ins = [];
const dels = [];
let rtOld = hash1(key, resFind.foundValue);
let rtNew = hash1(key, newValue);
ins.push([rtNew, [1, key, newValue ]]);
dels.push(rtOld);
const keyBits = this._splitBits(key);
for (let level = resFind.siblings.length-1; level >=0; level--) {
let oldNode, newNode;
const sibling = resFind.siblings[level];
if (keyBits[level]) {
oldNode = [sibling, rtOld];
newNode = [sibling, rtNew];
} else {
oldNode = [rtOld, sibling];
newNode = [rtNew, sibling];
}
rtOld = hash0(oldNode[0], oldNode[1]);
rtNew = hash0(newNode[0], newNode[1]);
dels.push(rtOld);
ins.push([rtNew, newNode]);
}
res.newRoot = rtNew;
await this.db.multiIns(ins);
await this.db.setRoot(rtNew);
this.root = rtNew;
await this.db.multiDel(dels);
return res;
}
async delete(_key) {
const key = bigInt(_key);
const resFind = await this.find(key);
if (!resFind.found) throw new Error("Key does not exists");
const res = {
siblings: [],
delKey: key,
delValue: resFind.foundValue
};
const dels = [];
const ins = [];
let rtOld = hash1(key, resFind.foundValue);
let rtNew;
dels.push(rtOld);
let mixed;
if (resFind.siblings.length > 0) {
const record = await this.db.get(resFind.siblings[resFind.siblings.length - 1]);
if ((record.length == 3)&&(record[0].equals(bigInt.one))) {
mixed = false;
res.oldKey = record[1];
res.oldValue = record[2];
res.isOld0 = false;
rtNew = resFind.siblings[resFind.siblings.length - 1];
} else if (record.length == 2) {
mixed = true;
res.oldKey = key;
res.oldValue = bigInt(0);
res.isOld0 = true;
rtNew = bigInt.zero;
} else {
throw new Error("Invalid node. Database corrupted");
}
} else {
rtNew = bigInt.zero;
res.oldKey = key;
res.oldValue = bigInt(0);
res.isOld0 = true;
}
const keyBits = this._splitBits(key);
for (let level = resFind.siblings.length-1; level >=0; level--) {
let newSibling = resFind.siblings[level];
if ((level == resFind.siblings.length-1)&&(!res.isOld0)) {
newSibling = bigInt.zero;
}
const oldSibling = resFind.siblings[level];
if (keyBits[level]) {
rtOld = hash0(oldSibling, rtOld);
} else {
rtOld = hash0(rtOld, oldSibling);
}
dels.push(rtOld);
if (!newSibling.isZero()) {
mixed = true;
}
if (mixed) {
res.siblings.unshift(resFind.siblings[level]);
let newNode;
if (keyBits[level]) {
newNode = [newSibling, rtNew];
} else {
newNode = [rtNew, newSibling];
}
rtNew = hash0(newNode[0], newNode[1]);
ins.push([rtNew, newNode]);
}
}
await this.db.multiIns(ins);
await this.db.setRoot(rtNew);
this.root = rtNew;
await this.db.multiDel(dels);
res.newRoot = rtNew;
res.oldRoot = rtOld;
return res;
}
async insert(_key, _value) {
const key = bigInt(_key);
const value = bigInt(_value);
let addedOne = false;
const res = {};
res.oldRoot = this.root;
const newKeyBits = this._splitBits(key);
let rtOld;
const resFind = await this.find(key);
if (resFind.found) throw new Error("Key already exists");
res.siblings = resFind.siblings;
let mixed;
if (!resFind.isOld0) {
const oldKeyits = this._splitBits(resFind.notFoundKey);
for (let i= res.siblings.length; oldKeyits[i] == newKeyBits[i]; i++) {
res.siblings.push(bigInt.zero);
}
rtOld = hash1(resFind.notFoundKey, resFind.notFoundValue);
res.siblings.push(rtOld);
addedOne = true;
mixed = false;
} else if (res.siblings.length >0) {
mixed = true;
rtOld = bigInt.zero;
}
const inserts = [];
const dels = [];
let rt = hash1(key, value);
inserts.push([rt,[1, key, value]] );
for (let i=res.siblings.length-1; i>=0; i--) {
if ((i<res.siblings.length-1)&&(!res.siblings[i].isZero())) {
mixed = true;
}
if (mixed) {
const oldSibling = resFind.siblings[i];
if (newKeyBits[i]) {
rtOld = hash0(oldSibling, rtOld);
} else {
rtOld = hash0(rtOld, oldSibling);
}
dels.push(rtOld);
}
let newRt;
if (newKeyBits[i]) {
newRt = hash0(res.siblings[i], rt);
inserts.push([newRt,[res.siblings[i], rt]] );
} else {
newRt = hash0(rt, res.siblings[i]);
inserts.push([newRt,[rt, res.siblings[i]]] );
}
rt = newRt;
}
if (addedOne) res.siblings.pop();
while ((res.siblings.length>0) && (res.siblings[res.siblings.length-1].isZero())) {
res.siblings.pop();
}
res.oldKey = resFind.notFoundKey;
res.oldValue = resFind.notFoundValue;
res.newRoot = rt;
res.isOld0 = resFind.isOld0;
await this.db.multiIns(inserts);
await this.db.setRoot(rt);
this.root = rt;
await this.db.multiDel(dels);
return res;
}
async find(key) {
const keyBits = this._splitBits(key);
return await this._find(key, keyBits, this.root, 0);
}
async _find(key, keyBits, root, level) {
if (typeof root === "undefined") root = this.root;
let res;
if (root.isZero()) {
res = {
found: false,
siblings: [],
notFoundKey: key,
notFoundValue: bigInt.zero,
isOld0: true
};
return res;
}
const record = await this.db.get(root);
if ((record.length==3)&&(record[0].equals(bigInt.one))) {
if (record[1].equals(key)) {
res = {
found: true,
siblings: [],
foundValue: record[2],
isOld0: false
};
} else {
res = {
found: false,
siblings: [],
notFoundKey: record[1],
notFoundValue: record[2],
isOld0: false
};
}
} else {
if (keyBits[level] == 0) {
res = await this._find(key, keyBits, record[0], level+1);
res.siblings.unshift(record[1]);
} else {
res = await this._find(key, keyBits, record[1], level+1);
res.siblings.unshift(record[0]);
}
}
return res;
}
}
async function loadFromFile(fileName) {
}
async function newMemEmptyTrie() {
const db = new SMTMemDB();
const rt = await db.getRoot();
const smt = new SMT(db, rt);
return smt;
}
module.exports.loadFromFile = loadFromFile;
module.exports.newMemEmptyTrie = newMemEmptyTrie;
module.exports.SMT = SMT;
module.exports.SMTMemDB = SMTMemDB;