You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
3.8 KiB

  1. use crate::prelude::*;
  2. use algebra::Field;
  3. use r1cs_core::{ConstraintSystem, SynthesisError};
  4. /// If `condition == 1`, then enforces that `self` and `other` are equal;
  5. /// otherwise, it doesn't enforce anything.
  6. pub trait ConditionalEqGadget<ConstraintF: Field>: Eq {
  7. fn conditional_enforce_equal<CS: ConstraintSystem<ConstraintF>>(
  8. &self,
  9. cs: CS,
  10. other: &Self,
  11. condition: &Boolean,
  12. ) -> Result<(), SynthesisError>;
  13. fn cost() -> usize;
  14. }
  15. impl<T: ConditionalEqGadget<ConstraintF>, ConstraintF: Field> ConditionalEqGadget<ConstraintF>
  16. for [T]
  17. {
  18. fn conditional_enforce_equal<CS: ConstraintSystem<ConstraintF>>(
  19. &self,
  20. mut cs: CS,
  21. other: &Self,
  22. condition: &Boolean,
  23. ) -> Result<(), SynthesisError> {
  24. for (i, (a, b)) in self.iter().zip(other.iter()).enumerate() {
  25. let mut cs = cs.ns(|| format!("Iteration {}", i));
  26. a.conditional_enforce_equal(&mut cs, b, condition)?;
  27. }
  28. Ok(())
  29. }
  30. fn cost() -> usize {
  31. unimplemented!()
  32. }
  33. }
  34. pub trait EqGadget<ConstraintF: Field>: Eq
  35. where
  36. Self: ConditionalEqGadget<ConstraintF>,
  37. {
  38. fn enforce_equal<CS: ConstraintSystem<ConstraintF>>(
  39. &self,
  40. cs: CS,
  41. other: &Self,
  42. ) -> Result<(), SynthesisError> {
  43. self.conditional_enforce_equal(cs, other, &Boolean::constant(true))
  44. }
  45. fn cost() -> usize {
  46. <Self as ConditionalEqGadget<ConstraintF>>::cost()
  47. }
  48. }
  49. impl<T: EqGadget<ConstraintF>, ConstraintF: Field> EqGadget<ConstraintF> for [T] {}
  50. pub trait NEqGadget<ConstraintF: Field>: Eq {
  51. fn enforce_not_equal<CS: ConstraintSystem<ConstraintF>>(
  52. &self,
  53. cs: CS,
  54. other: &Self,
  55. ) -> Result<(), SynthesisError>;
  56. fn cost() -> usize;
  57. }
  58. pub trait OrEqualsGadget<ConstraintF: Field>
  59. where
  60. Self: Sized,
  61. {
  62. fn enforce_equal_or<CS: ConstraintSystem<ConstraintF>>(
  63. cs: CS,
  64. cond: &Boolean,
  65. var: &Self,
  66. first: &Self,
  67. second: &Self,
  68. ) -> Result<(), SynthesisError>;
  69. fn cost() -> usize;
  70. }
  71. impl<ConstraintF: Field, T: Sized + ConditionalOrEqualsGadget<ConstraintF>>
  72. OrEqualsGadget<ConstraintF> for T
  73. {
  74. fn enforce_equal_or<CS: ConstraintSystem<ConstraintF>>(
  75. cs: CS,
  76. cond: &Boolean,
  77. var: &Self,
  78. first: &Self,
  79. second: &Self,
  80. ) -> Result<(), SynthesisError> {
  81. Self::conditional_enforce_equal_or(cs, cond, var, first, second, &Boolean::Constant(true))
  82. }
  83. fn cost() -> usize {
  84. <Self as ConditionalOrEqualsGadget<ConstraintF>>::cost()
  85. }
  86. }
  87. pub trait ConditionalOrEqualsGadget<ConstraintF: Field>
  88. where
  89. Self: Sized,
  90. {
  91. fn conditional_enforce_equal_or<CS: ConstraintSystem<ConstraintF>>(
  92. cs: CS,
  93. cond: &Boolean,
  94. var: &Self,
  95. first: &Self,
  96. second: &Self,
  97. should_enforce: &Boolean,
  98. ) -> Result<(), SynthesisError>;
  99. fn cost() -> usize;
  100. }
  101. impl<
  102. ConstraintF: Field,
  103. T: Sized + ConditionalEqGadget<ConstraintF> + CondSelectGadget<ConstraintF>,
  104. > ConditionalOrEqualsGadget<ConstraintF> for T
  105. {
  106. fn conditional_enforce_equal_or<CS: ConstraintSystem<ConstraintF>>(
  107. mut cs: CS,
  108. cond: &Boolean,
  109. var: &Self,
  110. first: &Self,
  111. second: &Self,
  112. should_enforce: &Boolean,
  113. ) -> Result<(), SynthesisError> {
  114. let match_opt = Self::conditionally_select(
  115. &mut cs.ns(|| "conditional_select_in_or"),
  116. cond,
  117. first,
  118. second,
  119. )?;
  120. var.conditional_enforce_equal(&mut cs.ns(|| "equals_in_or"), &match_opt, should_enforce)
  121. }
  122. fn cost() -> usize {
  123. <Self as ConditionalEqGadget<ConstraintF>>::cost()
  124. + <Self as CondSelectGadget<ConstraintF>>::cost()
  125. }
  126. }