//! Support for generating R1CS shape using bellperson. use std::{ cmp::Ordering, collections::{BTreeMap, HashMap}, }; use crate::traits::Group; use ff::{Field, PrimeField}; use bellperson::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; #[derive(Clone, Copy)] struct OrderedVariable(Variable); #[derive(Debug)] enum NamedObject { Constraint(usize), Var(Variable), Namespace, } impl Eq for OrderedVariable {} impl PartialEq for OrderedVariable { fn eq(&self, other: &OrderedVariable) -> bool { match (self.0.get_unchecked(), other.0.get_unchecked()) { (Index::Input(ref a), Index::Input(ref b)) => a == b, (Index::Aux(ref a), Index::Aux(ref b)) => a == b, _ => false, } } } impl PartialOrd for OrderedVariable { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for OrderedVariable { fn cmp(&self, other: &Self) -> Ordering { match (self.0.get_unchecked(), other.0.get_unchecked()) { (Index::Input(ref a), Index::Input(ref b)) => a.cmp(b), (Index::Aux(ref a), Index::Aux(ref b)) => a.cmp(b), (Index::Input(_), Index::Aux(_)) => Ordering::Less, (Index::Aux(_), Index::Input(_)) => Ordering::Greater, } } } #[allow(clippy::upper_case_acronyms)] /// `ShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. pub struct ShapeCS where G::Scalar: PrimeField + Field, { named_objects: HashMap, current_namespace: Vec, #[allow(clippy::type_complexity)] /// All constraints added to the `ShapeCS`. pub constraints: Vec<( LinearCombination, LinearCombination, LinearCombination, String, )>, inputs: Vec, aux: Vec, } fn proc_lc( terms: &LinearCombination, ) -> BTreeMap { let mut map = BTreeMap::new(); for (var, &coeff) in terms.iter() { map .entry(OrderedVariable(var)) .or_insert_with(Scalar::zero) .add_assign(&coeff); } // Remove terms that have a zero coefficient to normalize let mut to_remove = vec![]; for (var, coeff) in map.iter() { if coeff.is_zero().into() { to_remove.push(*var) } } for var in to_remove { map.remove(&var); } map } impl ShapeCS where G::Scalar: PrimeField, { /// Create a new, default `ShapeCS`, pub fn new() -> Self { ShapeCS::default() } /// Returns the number of constraints defined for this `ShapeCS`. pub fn num_constraints(&self) -> usize { self.constraints.len() } /// Returns the number of inputs defined for this `ShapeCS`. pub fn num_inputs(&self) -> usize { self.inputs.len() } /// Returns the number of aux inputs defined for this `ShapeCS`. pub fn num_aux(&self) -> usize { self.aux.len() } /// Print all public inputs, aux inputs, and constraint names. #[allow(dead_code)] pub fn pretty_print_list(&self) -> Vec { let mut result = Vec::new(); for input in &self.inputs { result.push(format!("INPUT {}", input)); } for aux in &self.aux { result.push(format!("AUX {}", aux)); } for &(ref _a, ref _b, ref _c, ref name) in &self.constraints { result.push(name.to_string()); } result } /// Print all iputs and a detailed representation of each constraint. #[allow(dead_code)] pub fn pretty_print(&self) -> String { let mut s = String::new(); for input in &self.inputs { s.push_str(&format!("INPUT {}\n", &input)) } let negone = -::one(); let powers_of_two = (0..G::Scalar::NUM_BITS) .map(|i| G::Scalar::from(2u64).pow_vartime(&[u64::from(i)])) .collect::>(); let pp = |s: &mut String, lc: &LinearCombination| { s.push('('); let mut is_first = true; for (var, coeff) in proc_lc::(lc) { if coeff == negone { s.push_str(" - ") } else if !is_first { s.push_str(" + ") } is_first = false; if coeff != ::one() && coeff != negone { for (i, x) in powers_of_two.iter().enumerate() { if x == &coeff { s.push_str(&format!("2^{} . ", i)); break; } } s.push_str(&format!("{:?} . ", coeff)) } match var.0.get_unchecked() { Index::Input(i) => { s.push_str(&format!("`I{}`", &self.inputs[i])); } Index::Aux(i) => { s.push_str(&format!("`A{}`", &self.aux[i])); } } } if is_first { // Nothing was visited, print 0. s.push('0'); } s.push(')'); }; for &(ref a, ref b, ref c, ref name) in &self.constraints { s.push('\n'); s.push_str(&format!("{}: ", name)); pp(&mut s, a); s.push_str(" * "); pp(&mut s, b); s.push_str(" = "); pp(&mut s, c); } s.push('\n'); s } /// Associate `NamedObject` with `path`. /// `path` must not already have an associated object. fn set_named_obj(&mut self, path: String, to: NamedObject) { assert!( !self.named_objects.contains_key(&path), "tried to create object at existing path: {}", path ); self.named_objects.insert(path, to); } } impl Default for ShapeCS where G::Scalar: PrimeField, { fn default() -> Self { let mut map = HashMap::new(); map.insert("ONE".into(), NamedObject::Var(ShapeCS::::one())); ShapeCS { named_objects: map, current_namespace: vec![], constraints: vec![], inputs: vec![String::from("ONE")], aux: vec![], } } } impl ConstraintSystem for ShapeCS where G::Scalar: PrimeField, { type Root = Self; fn alloc(&mut self, annotation: A, _f: F) -> Result where F: FnOnce() -> Result, A: FnOnce() -> AR, AR: Into, { let path = compute_path(&self.current_namespace, &annotation().into()); self.aux.push(path); Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1))) } fn alloc_input(&mut self, annotation: A, _f: F) -> Result where F: FnOnce() -> Result, A: FnOnce() -> AR, AR: Into, { let path = compute_path(&self.current_namespace, &annotation().into()); self.inputs.push(path); Ok(Variable::new_unchecked(Index::Input(self.inputs.len() - 1))) } fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) where A: FnOnce() -> AR, AR: Into, LA: FnOnce(LinearCombination) -> LinearCombination, LB: FnOnce(LinearCombination) -> LinearCombination, LC: FnOnce(LinearCombination) -> LinearCombination, { let path = compute_path(&self.current_namespace, &annotation().into()); let index = self.constraints.len(); self.set_named_obj(path.clone(), NamedObject::Constraint(index)); let a = a(LinearCombination::zero()); let b = b(LinearCombination::zero()); let c = c(LinearCombination::zero()); self.constraints.push((a, b, c, path)); } fn push_namespace(&mut self, name_fn: N) where NR: Into, N: FnOnce() -> NR, { let name = name_fn().into(); let path = compute_path(&self.current_namespace, &name); self.set_named_obj(path, NamedObject::Namespace); self.current_namespace.push(name); } fn pop_namespace(&mut self) { assert!(self.current_namespace.pop().is_some()); } fn get_root(&mut self) -> &mut Self::Root { self } } fn compute_path(ns: &[String], this: &str) -> String { assert!( !this.chars().any(|a| a == '/'), "'/' is not allowed in names" ); let mut name = String::new(); let mut needs_separation = false; for ns in ns.iter().chain(Some(this.to_string()).iter()) { if needs_separation { name += "/"; } name += ns; needs_separation = true; } name }