diff --git a/Cargo.toml b/Cargo.toml index 56bbf9e..9754f32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "plonky" description = "Recursive SNARKs based on Plonk and Halo" version = "0.1.0" -authors = ["Daniel Lubarov"] +authors = ["Daniel Lubarov ", "William Borgeaud "] readme = "README.md" license = "MIT OR Apache-2.0" repository = "https://github.com/mir-protocol/plonky" diff --git a/src/field/field.rs b/src/field/field.rs index ad7fc58..9682f55 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -340,6 +340,10 @@ pub trait Field: self.exp(Self::from_canonical_usize(power)) } + fn kth_root_usize(&self, k: usize) -> Self { + self.kth_root(Self::from_canonical_usize(k)) + } + fn kth_root_u32(&self, k: u32) -> Self { self.kth_root(Self::from_canonical_u32(k)) } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs new file mode 100644 index 0000000..ba4f77c --- /dev/null +++ b/src/gadgets/mod.rs @@ -0,0 +1,3 @@ +pub use range_check::*; + +mod range_check; diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/gates2/arithmetic.rs b/src/gates2/arithmetic.rs new file mode 100644 index 0000000..917dfcf --- /dev/null +++ b/src/gates2/arithmetic.rs @@ -0,0 +1,96 @@ +use crate::{CircuitBuilder2, CircuitConfig, ConstraintPolynomial, DeterministicGate, DeterministicGateAdapter, Field, GateRef, Target2}; + +/// A gate which can be configured to perform various arithmetic. In particular, it computes +/// +/// ```text +/// output := product_weight * multiplicand_0 * multiplicand_1 +/// + addend_weight * addend +/// ``` +/// +/// where `product_weight` and `addend_weight` are constants, and the other variables are wires. +#[derive(Eq, PartialEq, Hash)] +pub struct ArithmeticGate2; + +impl ArithmeticGate2 { + pub fn get_ref() -> GateRef { + GateRef::new(DeterministicGateAdapter::new(ArithmeticGate2)) + } + + pub const CONST_PRODUCT_WEIGHT: usize = 0; + pub const CONST_ADDEND_WEIGHT: usize = 1; + + pub const WIRE_MULTIPLICAND_0: usize = 0; + pub const WIRE_MULTIPLICAND_1: usize = 1; + pub const WIRE_ADDEND: usize = 2; + pub const WIRE_OUTPUT: usize = 3; + + /// Computes `x y + z`. + pub fn mul_add( + builder: &mut CircuitBuilder2, + x: Target2, + y: Target2, + z: Target2, + ) -> Target2 { + let gate_type = ArithmeticGate2::get_ref(); + let constants = vec![F::ONE, F::ONE]; + let gate = builder.add_gate(gate_type, constants); + + builder.route(x, Target2::wire(gate, Self::WIRE_MULTIPLICAND_0)); + builder.route(y, Target2::wire(gate, Self::WIRE_MULTIPLICAND_1)); + builder.route(z, Target2::wire(gate, Self::WIRE_ADDEND)); + + Target2::wire(gate, Self::WIRE_OUTPUT) + } + + /// Computes `x y`. + pub fn mul( + builder: &mut CircuitBuilder2, + x: Target2, + y: Target2, + ) -> Target2 { + let zero = builder.zero(); + Self::mul_add(builder, x, y, zero) + } + + /// Computes `x + y`. + pub fn add( + builder: &mut CircuitBuilder2, + x: Target2, + y: Target2, + ) -> Target2 { + let one = builder.one(); + Self::mul_add(builder, x, one, y) + } +} + +impl DeterministicGate for ArithmeticGate2 { + fn id(&self) -> String { + "ArithmeticGate".into() + } + + fn outputs(&self, _config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { + let const_0 = ConstraintPolynomial::local_constant(Self::CONST_PRODUCT_WEIGHT); + let const_1 = ConstraintPolynomial::local_constant(Self::CONST_ADDEND_WEIGHT); + let multiplicand_0 = ConstraintPolynomial::local_wire_value(Self::WIRE_MULTIPLICAND_0); + let multiplicand_1 = ConstraintPolynomial::local_wire_value(Self::WIRE_MULTIPLICAND_1); + let addend = ConstraintPolynomial::local_wire_value(Self::WIRE_ADDEND); + + let out = const_0 * multiplicand_0 * &multiplicand_1 + const_1 * &addend; + vec![(Self::WIRE_OUTPUT, out)] + } +} + +#[cfg(test)] +mod tests { + use crate::{CircuitBuilder2, CircuitConfig, TweedledumBase}; + use crate::gates2::arithmetic::ArithmeticGate2; + + fn add() { + let config = CircuitConfig { num_wires: 3, num_routed_wires: 3, security_bits: 128 }; + let mut builder = CircuitBuilder2::::new(config); + let one = builder.one(); + let two = builder.two(); + let sum = ArithmeticGate2::add(&mut builder, one, one); + todo!() + } +} diff --git a/src/gates2/buffer.rs b/src/gates2/buffer.rs new file mode 100644 index 0000000..27cc2b3 --- /dev/null +++ b/src/gates2/buffer.rs @@ -0,0 +1,17 @@ +use crate::{CircuitConfig, ConstraintPolynomial, DeterministicGate, Field}; + +/// A gate which doesn't perform any arithmetic, but just acts as a buffer for receiving data. +/// Some gates, such as the Rescue round gate, "output" their results using one of the next gate's +/// "input" wires. The last such gate has no next gate of the same type, so we add a buffer gate +/// for receiving the last gate's output. +pub struct BufferGate2; + +impl DeterministicGate for BufferGate2 { + fn id(&self) -> String { + "Buffer".into() + } + + fn outputs(&self, _config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { + Vec::new() + } +} diff --git a/src/gates2/constant.rs b/src/gates2/constant.rs new file mode 100644 index 0000000..612aeb0 --- /dev/null +++ b/src/gates2/constant.rs @@ -0,0 +1,25 @@ +use crate::{CircuitConfig, ConstraintPolynomial, DeterministicGate, DeterministicGateAdapter, Field, GateRef}; + +/// A gate which takes a single constant parameter and outputs that value. +pub struct ConstantGate2; + +impl ConstantGate2 { + pub fn get_ref() -> GateRef { + GateRef::new(DeterministicGateAdapter::new(ConstantGate2)) + } + + pub const CONST_INPUT: usize = 0; + + pub const WIRE_OUTPUT: usize = 0; +} + +impl DeterministicGate for ConstantGate2 { + fn id(&self) -> String { + "ConstantGate".into() + } + + fn outputs(&self, _config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { + let out = ConstraintPolynomial::local_constant(Self::CONST_INPUT); + vec![(Self::WIRE_OUTPUT, out)] + } +} diff --git a/src/gates2/curve_add.rs b/src/gates2/curve_add.rs new file mode 100644 index 0000000..08cd34e --- /dev/null +++ b/src/gates2/curve_add.rs @@ -0,0 +1,201 @@ +use serde::export::PhantomData; + +use crate::{CircuitConfig, ConstraintPolynomial, Curve, Field, Gate2, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2}; + +pub struct CurveAddGate2 { + _phantom: PhantomData, +} + +impl CurveAddGate2 { + pub const WIRE_GROUP_ACC_X: usize = 0; + pub const WIRE_GROUP_ACC_Y: usize = 1; + pub const WIRE_SCALAR_ACC_OLD: usize = 2; + pub const WIRE_SCALAR_ACC_NEW: usize = 3; + pub const WIRE_ADDEND_X: usize = 4; + pub const WIRE_ADDEND_Y: usize = 5; + pub const WIRE_SCALAR_BIT: usize = 6; + pub const WIRE_INVERSE: usize = 7; + pub const WIRE_LAMBDA: usize = 8; +} + +impl Gate2 for CurveAddGate2 { + fn id(&self) -> String { + "CurveAddGate".into() + } + + fn constraints( + &self, + _config: CircuitConfig, + ) -> Vec::BaseField>> { + // Notation: + // - p1 is the accumulator; + // - p2 is the addend; + // - p3 = p1 + p2; + // - p4 = if scalar_bit { p3 } else { p1 } + + let x1 = ConstraintPolynomial::::local_wire_value(Self::WIRE_GROUP_ACC_X); + let y1 = ConstraintPolynomial::::local_wire_value(Self::WIRE_GROUP_ACC_Y); + let x4 = ConstraintPolynomial::::next_wire_value(Self::WIRE_GROUP_ACC_X); + let y4 = ConstraintPolynomial::::next_wire_value(Self::WIRE_GROUP_ACC_Y); + let scalar_acc_old = ConstraintPolynomial::::local_wire_value(Self::WIRE_SCALAR_ACC_OLD); + let scalar_acc_new = ConstraintPolynomial::::local_wire_value(Self::WIRE_SCALAR_ACC_NEW); + let x2 = ConstraintPolynomial::::local_wire_value(Self::WIRE_ADDEND_X); + let y2 = ConstraintPolynomial::::local_wire_value(Self::WIRE_ADDEND_Y); + let scalar_bit = ConstraintPolynomial::::local_wire_value(Self::WIRE_SCALAR_BIT); + let inverse = ConstraintPolynomial::::local_wire_value(Self::WIRE_INVERSE); + let lambda = ConstraintPolynomial::::local_wire_value(Self::WIRE_LAMBDA); + + let computed_lambda = (&y1 - &y2) * &inverse; + let x3 = lambda.square() - &x1 - &x2; + // We subtract x4 instead of x3 in order to minimize degree. This will give an incorrect + // result for y3 if x3 != x4, which happens when scalar_bit = 0, but in that case y3 will + // be ignored (i.e. multiplied by zero), so we're okay. + let y3 = &lambda * (&x1 - &x4) - &y1; + + let not_scalar_bit = ConstraintPolynomial::constant_usize(1) - &scalar_bit; + let computed_x4 = &scalar_bit * &x3 + ¬_scalar_bit * &x1; + let computed_y4 = &scalar_bit * &y3 + ¬_scalar_bit * &y1; + + vec![ + &computed_lambda - &lambda, + &computed_x4 - &x4, + &computed_y4 - &y4, + &scalar_acc_new - (scalar_acc_old.double() + &scalar_bit), + &scalar_bit * ¬_scalar_bit, + &inverse * (&x1 - &x2) - C::BaseField::ONE, + ] + } + + fn generators( + &self, + _config: CircuitConfig, + gate_index: usize, + _local_constants: Vec<::BaseField>, + _next_constants: Vec<::BaseField>, + ) -> Vec>> { + let gen = CurveAddGateGenerator:: { gate_index, _phantom: PhantomData }; + vec![Box::new(gen)] + } +} + +struct CurveAddGateGenerator { + gate_index: usize, + _phantom: PhantomData, +} + +impl SimpleGenerator for CurveAddGateGenerator { + fn dependencies(&self) -> Vec> { + vec![ + Target2::Wire(Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_GROUP_ACC_X, + }), + Target2::Wire(Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_GROUP_ACC_Y, + }), + Target2::Wire(Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_SCALAR_ACC_OLD, + }), + Target2::Wire(Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_ADDEND_X, + }), + Target2::Wire(Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_ADDEND_Y, + }), + ] + } + + fn run_once( + &mut self, + witness: &PartialWitness2, + ) -> PartialWitness2 { + // Notation: + // - p1 is the accumulator; + // - p2 is the addend; + // - p3 = p1 + p2; + // - p4 = if scalar_bit { p3 } else { p1 } + + let x1_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_GROUP_ACC_X, + }; + let y1_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_GROUP_ACC_Y, + }; + let x4_wire = Wire { + gate: self.gate_index + 1, + input: CurveAddGate2::::WIRE_GROUP_ACC_X, + }; + let y4_wire = Wire { + gate: self.gate_index + 1, + input: CurveAddGate2::::WIRE_GROUP_ACC_Y, + }; + let scalar_acc_old_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_SCALAR_ACC_OLD, + }; + let scalar_acc_new_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_SCALAR_ACC_NEW, + }; + let x2_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_ADDEND_X, + }; + let y2_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_ADDEND_Y, + }; + let scalar_bit_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_SCALAR_BIT, + }; + let inverse_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_INVERSE, + }; + let lambda_wire = Wire { + gate: self.gate_index, + input: CurveAddGate2::::WIRE_LAMBDA, + }; + + let x1 = witness.get_wire(x1_wire); + let y1 = witness.get_wire(y1_wire); + + let scalar_acc_old = witness.get_wire(scalar_acc_old_wire); + + let x2 = witness.get_wire(x2_wire); + let y2 = witness.get_wire(y2_wire); + + let scalar_bit = witness.get_wire(scalar_bit_wire); + debug_assert!(scalar_bit.is_zero() || scalar_bit.is_one()); + + let scalar_acc_new = scalar_acc_old.double() + scalar_bit; + + let dx = x1 - x2; + let dy = y1 - y2; + let inverse = dx.multiplicative_inverse().expect("x_1 = x_2"); + let lambda = dy * inverse; + let x3 = lambda.square() - x1 - x2; + let y3 = lambda * (x1 - x3) - y1; + + let (x4, y4) = if scalar_bit.is_one() { + (x3, y3) + } else { + (x1, y1) + }; + + let mut result = PartialWitness2::new(); + result.set_wire(x4_wire, x4); + result.set_wire(y4_wire, y4); + result.set_wire(scalar_acc_new_wire, scalar_acc_new); + result.set_wire(inverse_wire, inverse); + result.set_wire(lambda_wire, lambda); + result + } +} diff --git a/src/gates2/curve_dbl.rs b/src/gates2/curve_dbl.rs new file mode 100644 index 0000000..f133943 --- /dev/null +++ b/src/gates2/curve_dbl.rs @@ -0,0 +1,103 @@ +use std::marker::PhantomData; + +use crate::{CircuitConfig, ConstraintPolynomial, Curve, Field, Gate2, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2}; + +pub struct CurveDblGate2 { + _phantom: PhantomData, +} + +impl CurveDblGate2 { + pub const WIRE_X_OLD: usize = 0; + pub const WIRE_Y_OLD: usize = 1; + pub const WIRE_X_NEW: usize = 2; + pub const WIRE_Y_NEW: usize = 3; + pub const WIRE_INVERSE: usize = 4; + pub const WIRE_LAMBDA: usize = 5; +} + +impl Gate2 for CurveDblGate2 { + fn id(&self) -> String { + "CurveDblGate".into() + } + + fn constraints(&self, _config: CircuitConfig) -> Vec> { + let x_old = ConstraintPolynomial::::local_wire_value(Self::WIRE_X_OLD); + let y_old = ConstraintPolynomial::::local_wire_value(Self::WIRE_Y_OLD); + let x_new = ConstraintPolynomial::::local_wire_value(Self::WIRE_X_NEW); + let y_new = ConstraintPolynomial::::local_wire_value(Self::WIRE_Y_NEW); + let inverse = ConstraintPolynomial::::local_wire_value(Self::WIRE_INVERSE); + let lambda = ConstraintPolynomial::::local_wire_value(Self::WIRE_LAMBDA); + + let computed_lambda_numerator = x_old.square().triple() + C::A; + let computed_lambda = &computed_lambda_numerator * &inverse; + let computed_x_new = lambda.square() - x_old.double(); + let computed_y_new = &lambda * (&x_old - &x_new) - &y_old; + + vec![ + // Verify that computed_lambda matches lambda. + &computed_lambda - &lambda, + // Verify that computed_x_new matches x_new. + &computed_x_new - &x_new, + // Verify that computed_y_new matches y_new. + &computed_y_new - &y_new, + // Verify that 2 * y_old times its purported inverse is 1. + y_old.double() * &inverse - 1, + ] + } + + fn generators( + &self, + _config: CircuitConfig, + gate_index: usize, + _local_constants: Vec, + _next_constants: Vec, + ) -> Vec>> { + let gen = CurveDblGateGenerator:: { gate_index, _phantom: PhantomData }; + vec![Box::new(gen)] + } +} + +struct CurveDblGateGenerator { + gate_index: usize, + _phantom: PhantomData, +} + +impl SimpleGenerator for CurveDblGateGenerator { + fn dependencies(&self) -> Vec> { + vec![ + Target2::Wire(Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_X_OLD }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_Y_OLD }), + ] + } + + fn run_once( + &mut self, + witness: &PartialWitness2, + ) -> PartialWitness2 { + let x_old_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_X_OLD }; + let y_old_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_Y_OLD }; + let x_new_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_X_NEW }; + let y_new_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_Y_NEW }; + let inverse_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_INVERSE }; + let lambda_wire = Wire { gate: self.gate_index, input: CurveDblGate2::::WIRE_LAMBDA }; + + let x_old = witness.get_wire(x_old_wire); + let y_old = witness.get_wire(y_old_wire); + + let inverse = y_old.double().multiplicative_inverse().expect("y = 0"); + let mut lambda_numerator = x_old.square().triple(); + if C::A.is_nonzero() { + lambda_numerator = lambda_numerator + C::A; + } + let lambda = lambda_numerator * inverse; + let x_new = lambda.square() - x_old.double(); + let y_new = lambda * (x_old - x_new) - y_old; + + let mut result = PartialWitness2::new(); + result.set_wire(inverse_wire, inverse); + result.set_wire(lambda_wire, lambda); + result.set_wire(x_new_wire, x_new); + result.set_wire(y_new_wire, y_new); + result + } +} diff --git a/src/gates2/curve_endo.rs b/src/gates2/curve_endo.rs new file mode 100644 index 0000000..aad7360 --- /dev/null +++ b/src/gates2/curve_endo.rs @@ -0,0 +1,134 @@ +use std::marker::PhantomData; + +use crate::{CircuitConfig, ConstraintPolynomial, Curve, Field, Gate2, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2, HaloCurve}; + +/// Performs a step of Halo's accumulate-with-endomorphism loop. +pub struct CurveEndoGate2 { + _phantom: PhantomData, +} + +impl CurveEndoGate2 { + pub const WIRE_ADDEND_X: usize = 0; + pub const WIRE_ADDEND_Y: usize = 1; + pub const WIRE_SCALAR_BIT_0: usize = 2; + pub const WIRE_SCALAR_BIT_1: usize = 3; + pub const WIRE_GROUP_ACC_X: usize = 4; + pub const WIRE_GROUP_ACC_Y: usize = 5; + pub const WIRE_INVERSE: usize = 6; + pub const WIRE_LAMBDA: usize = 7; +} + +impl Gate2 for CurveEndoGate2 { + fn id(&self) -> String { + "CurveEndoGate".into() + } + + fn constraints(&self, _config: CircuitConfig) -> Vec> { + let addend_x = ConstraintPolynomial::::local_wire_value(Self::WIRE_ADDEND_X); + let addend_y = ConstraintPolynomial::::local_wire_value(Self::WIRE_ADDEND_Y); + let scalar_bit_0 = ConstraintPolynomial::::local_wire_value(Self::WIRE_SCALAR_BIT_0); + let scalar_bit_1 = ConstraintPolynomial::::local_wire_value(Self::WIRE_SCALAR_BIT_1); + let group_acc_old_x = ConstraintPolynomial::::local_wire_value(Self::WIRE_GROUP_ACC_X); + let group_acc_old_y = ConstraintPolynomial::::local_wire_value(Self::WIRE_GROUP_ACC_Y); + let group_acc_new_x = ConstraintPolynomial::::next_wire_value(Self::WIRE_GROUP_ACC_X); + let group_acc_new_y = ConstraintPolynomial::::next_wire_value(Self::WIRE_GROUP_ACC_Y); + let inverse = ConstraintPolynomial::::local_wire_value(Self::WIRE_INVERSE); + let lambda = ConstraintPolynomial::::local_wire_value(Self::WIRE_LAMBDA); + + // Conditionally apply the endo and conditionally negate in order to get S_i, which is + // the actual point we want to add to the accumulator. + let s_i_x = (&scalar_bit_1 * (C::ZETA - C::BaseField::ONE) + 1) * &addend_x; + let s_i_y = (scalar_bit_0.double() - 1) * &addend_y; + + let computed_lambda = (&group_acc_old_y - &s_i_y) * &inverse; + let computed_group_acc_new_x = lambda.square() - &group_acc_old_x - &s_i_x; + let computed_group_acc_new_y = &lambda * (&group_acc_old_x - &group_acc_new_x) - &group_acc_old_y; + + vec![ + &computed_group_acc_new_x - &group_acc_new_x, + &computed_group_acc_new_y - &group_acc_new_y, + &scalar_bit_0 * (&scalar_bit_0 - 1), + &scalar_bit_1 * (&scalar_bit_1 - 1), + &inverse * (&group_acc_old_x - &s_i_x) - 1, + &computed_lambda - &lambda, + ] + } + + fn generators( + &self, + _config: CircuitConfig, + gate_index: usize, + _local_constants: Vec, + _next_constants: Vec<::BaseField>, + ) -> Vec>> { + let gen = CurveEndoGateGenerator:: { gate_index, _phantom: PhantomData }; + vec![Box::new(gen)] + } +} + +struct CurveEndoGateGenerator { + gate_index: usize, + _phantom: PhantomData, +} + +impl SimpleGenerator for CurveEndoGateGenerator { + fn dependencies(&self) -> Vec> { + vec![ + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_ADDEND_X }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_ADDEND_Y }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_SCALAR_BIT_0 }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_SCALAR_BIT_1 }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_GROUP_ACC_X }), + Target2::Wire(Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_GROUP_ACC_Y }), + ] + } + + fn run_once( + &mut self, + witness: &PartialWitness2, + ) -> PartialWitness2 { + let addend_x_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_ADDEND_X }; + let addend_y_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_ADDEND_Y }; + let scalar_bit_0_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_SCALAR_BIT_0 }; + let scalar_bit_1_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_SCALAR_BIT_1 }; + let group_acc_old_x_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_GROUP_ACC_X }; + let group_acc_old_y_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_GROUP_ACC_Y }; + let group_acc_new_x_wire = Wire { gate: self.gate_index + 1, input: CurveEndoGate2::::WIRE_GROUP_ACC_X }; + let group_acc_new_y_wire = Wire { gate: self.gate_index + 1, input: CurveEndoGate2::::WIRE_GROUP_ACC_Y }; + let inverse_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_INVERSE }; + let lambda_wire = Wire { gate: self.gate_index, input: CurveEndoGate2::::WIRE_LAMBDA }; + + // Load input values. + let addend_x = witness.get_wire(addend_x_wire); + let addend_y = witness.get_wire(addend_y_wire); + let scalar_bit_0 = witness.get_wire(scalar_bit_0_wire); + let scalar_bit_1 = witness.get_wire(scalar_bit_1_wire); + let group_acc_old_x = witness.get_wire(group_acc_old_x_wire); + let group_acc_old_y = witness.get_wire(group_acc_old_y_wire); + + // Compute S_i as defined in Halo. + let mut s_i_x = addend_x; + if scalar_bit_0 == C::BaseField::ONE { + s_i_x = s_i_x * C::ZETA; + } + let mut s_i_y = addend_y; + if scalar_bit_1 == C::BaseField::ZERO { + s_i_y = -s_i_y; + } + + // Compute group_acc_new = group_acc_old_x + s_i. + let dx = group_acc_old_x - s_i_x; + let dy = group_acc_old_y - s_i_y; + let inverse = dx.multiplicative_inverse().expect("x_1 = x_2"); + let lambda = dy * inverse; + let group_acc_new_x = lambda.square() - group_acc_old_x - s_i_x; + let group_acc_new_y = lambda * (group_acc_old_x - group_acc_new_x) - group_acc_old_y; + + let mut result = PartialWitness2::new(); + result.set_wire(inverse_wire, inverse); + result.set_wire(lambda_wire, lambda); + result.set_wire(group_acc_new_x_wire, group_acc_new_x); + result.set_wire(group_acc_new_y_wire, group_acc_new_y); + result + } +} diff --git a/src/gates2/deterministic_gate.rs b/src/gates2/deterministic_gate.rs new file mode 100644 index 0000000..ddf6d2f --- /dev/null +++ b/src/gates2/deterministic_gate.rs @@ -0,0 +1,139 @@ +use serde::export::PhantomData; + +use crate::{CircuitConfig, ConstraintPolynomial, EvaluationVars, Field, Gate2, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2}; + +/// A deterministic gate. Each entry in `outputs()` describes how that output is evaluated; this is +/// used to create both the constraint set and the generator set. +/// +/// `DeterministicGate`s do not automatically implement `Gate`; they should instead be wrapped in +/// `DeterministicGateAdapter`. +pub trait DeterministicGate: 'static { + /// A unique identifier for this gate. + fn id(&self) -> String; + + /// A vector of `(i, c)` pairs, where `i` is the index of an output and `c` is the polynomial + /// defining how that output is evaluated. + fn outputs(&self, config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)>; + + /// Any additional constraints to be enforced, besides the (automatically provided) ones that + /// constraint output values. + fn additional_constraints(&self, _config: CircuitConfig) -> Vec> { + Vec::new() + } + + /// Any additional generators, besides the (automatically provided) ones that generate output + /// values. + fn additional_generators( + &self, + _config: CircuitConfig, + _gate_index: usize, + ) -> Vec>> { + Vec::new() + } +} + +/// A wrapper around `DeterministicGate` which implements `Gate`. Note that a blanket implementation +/// is not possible in this context given Rust's coherence rules. +pub struct DeterministicGateAdapter + ?Sized> { + gate: Box, + _phantom: PhantomData, +} + +impl> DeterministicGateAdapter { + pub fn new(gate: DG) -> Self { + Self { gate: Box::new(gate), _phantom: PhantomData } + } +} + +impl> Gate2 for DeterministicGateAdapter { + fn id(&self) -> String { + self.gate.id() + } + + fn constraints(&self, config: CircuitConfig) -> Vec> { + // For each output, we add a constraint of the form `out - expression = 0`, + // then we append any additional constraints that the gate defines. + self.gate.outputs(config).into_iter() + .map(|(i, out)| out - ConstraintPolynomial::local_wire_value(i)) + .chain(self.gate.additional_constraints(config).into_iter()) + .collect() + } + + fn generators( + &self, + config: CircuitConfig, + gate_index: usize, + local_constants: Vec, + next_constants: Vec, + ) -> Vec>> { + self.gate.outputs(config) + .into_iter() + .map(|(input_index, out)| { + let og = OutputGenerator { + gate_index, + input_index, + out, + local_constants: local_constants.clone(), + next_constants: next_constants.clone(), + }; + + // We need the type system to treat this as a boxed `WitnessGenerator2`, rather + // than a boxed `OutputGenerator`. + let b: Box::> = Box::new(og); + b + }) + .chain(self.gate.additional_generators(config, gate_index)) + .collect() + } +} + +struct OutputGenerator { + gate_index: usize, + input_index: usize, + out: ConstraintPolynomial, + local_constants: Vec, + next_constants: Vec, +} + +impl SimpleGenerator for OutputGenerator { + fn dependencies(&self) -> Vec> { + self.out.dependencies(self.gate_index) + .into_iter() + .map(Target2::Wire) + .collect() + } + + fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + let mut local_wire_values = Vec::new(); + let mut next_wire_values = Vec::new(); + + // Get an exclusive upper bound on the largest input index in this constraint. + let input_limit_exclusive = self.out.max_wire_input_index() + .map_or(0, |i| i + 1); + + for input in 0..input_limit_exclusive { + let local_wire = Wire { gate: self.gate_index, input }; + let next_wire = Wire { gate: self.gate_index + 1, input }; + + // Lookup the values if they exist. If not, we can just insert a zero, knowing + // that it will not be used. (If it was used, it would have been included in our + // dependencies, and this generator would not have run yet.) + let local_value = witness.try_get_target(Target2::Wire(local_wire)).unwrap_or(F::ZERO); + let next_value = witness.try_get_target(Target2::Wire(next_wire)).unwrap_or(F::ZERO); + + local_wire_values.push(local_value); + next_wire_values.push(next_value); + } + + let vars = EvaluationVars { + local_constants: &self.local_constants, + next_constants: &self.next_constants, + local_wire_values: &local_wire_values, + next_wire_values: &next_wire_values, + }; + + let result_wire = Wire { gate: self.gate_index, input: self.input_index }; + let result_value = self.out.evaluate(vars); + PartialWitness2::singleton(Target2::Wire(result_wire), result_value) + } +} diff --git a/src/gates2/gate.rs b/src/gates2/gate.rs new file mode 100644 index 0000000..32a6a6b --- /dev/null +++ b/src/gates2/gate.rs @@ -0,0 +1,78 @@ +use std::hash::{Hash, Hasher}; +use std::rc::Rc; + +use crate::{CircuitConfig, ConstraintPolynomial, Field, WitnessGenerator2}; + +/// A custom gate. +// TODO: Remove CircuitConfig params? Could just use fields within each struct. +pub trait Gate2: 'static { + fn id(&self) -> String; + + /// A set of expressions which must evaluate to zero. + fn constraints(&self, config: CircuitConfig) -> Vec>; + + fn generators( + &self, + config: CircuitConfig, + gate_index: usize, + local_constants: Vec, + next_constants: Vec, + ) -> Vec>>; + + /// The number of constants used by this gate. + fn num_constants(&self, config: CircuitConfig) -> usize { + self.constraints(config) + .into_iter() + .map(|c| c.max_constant_index().map_or(0, |i| i + 1)) + .max() + .unwrap_or(0) + } + + /// The minimum number of wires required to use this gate. + fn min_wires(&self, config: CircuitConfig) -> usize { + self.constraints(config) + .into_iter() + .map(|c| c.max_wire_input_index().map_or(0, |i| i + 1)) + .max() + .unwrap_or(0) + } + + /// The maximum degree among this gate's constraint polynomials. + fn degree(&self, config: CircuitConfig) -> usize { + self.constraints(config) + .into_iter() + .map(|c| c.degree()) + .max() + .unwrap_or(0) + } +} + +/// A wrapper around an `Rc` which implements `PartialEq`, `Eq` and `Hash` based on gate IDs. +#[derive(Clone)] +pub struct GateRef(pub(crate) Rc>); + +impl GateRef { + pub fn new>(gate: G) -> GateRef { + GateRef(Rc::new(gate)) + } +} + +impl PartialEq for GateRef { + fn eq(&self, other: &Self) -> bool { + self.0.id() == other.0.id() + } +} + +impl Hash for GateRef { + fn hash(&self, state: &mut H) { + self.0.id().hash(state) + } +} + +impl Eq for GateRef {} + +/// A gate along with any constants used to configure it. +pub struct GateInstance { + pub gate_type: GateRef, + pub constants: Vec, +} diff --git a/src/gates2/limb_sum.rs b/src/gates2/limb_sum.rs new file mode 100644 index 0000000..197fe06 --- /dev/null +++ b/src/gates2/limb_sum.rs @@ -0,0 +1,59 @@ +use crate::{CircuitConfig, ConstraintPolynomial, DeterministicGate, DeterministicGateAdapter, Field, GateRef}; + +/// A gate which takes as inputs limbs of sum small base, verifies that each limb is in `[0, base)`, +/// and outputs the weighted sum `limb[0] + base limb[1] + base^2 limb[2] + ...`. +pub struct LimbSumGate { + base: usize, + num_limbs: usize, +} + +impl LimbSumGate { + pub fn get_ref(base: usize, num_limbs: usize) -> GateRef { + let gate = LimbSumGate { base, num_limbs }; + GateRef::new(DeterministicGateAdapter::new(gate)) + } +} + +impl DeterministicGate for LimbSumGate { + fn id(&self) -> String { + format!("LimbSumGate[base={}, num_limbs={}]", self.base, self.num_limbs) + } + + fn outputs(&self, _config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { + // We compute `out = limb[0] + base * limb[1] + base^2 * limb[2] + ...`. + let out = (0..self.num_limbs).map(|i| { + let limb = ConstraintPolynomial::local_wire_value(i); + let weight = F::from_canonical_usize(self.base).exp_usize(i); + limb * weight + }).sum(); + + vec![(self.num_limbs, out)] + } + + fn additional_constraints(&self, _config: CircuitConfig) -> Vec> { + // For each limb, + (0..self.num_limbs).map(|i| { + let limb = ConstraintPolynomial::local_wire_value(i); + + // Assert that this limb is in `[0, base)` by enforcing that + // `limb (limb - 1) .. (limb - (base - 1)) = 0`. + (0..self.base).map(|possible_value| { + &limb - possible_value + }).product() + }).collect() + } +} + +#[cfg(test)] +mod tests { + use crate::{CircuitBuilder2, CircuitConfig, TweedledumBase}; + + fn valid() { + let config = CircuitConfig { num_wires: 3, num_routed_wires: 3, security_bits: 128 }; + let mut builder = CircuitBuilder2::::new(config); + let zero = builder.zero(); + let one = builder.one(); + let two = builder.two(); + todo!() + } +} diff --git a/src/gates2/mod.rs b/src/gates2/mod.rs new file mode 100644 index 0000000..acd1cc4 --- /dev/null +++ b/src/gates2/mod.rs @@ -0,0 +1,23 @@ +pub use arithmetic::*; +pub use buffer::*; +pub use constant::*; +pub use curve_add::*; +pub use curve_dbl::*; +pub use curve_endo::*; +pub use deterministic_gate::*; +pub use gate::*; +pub use limb_sum::*; +pub use public_input::*; +pub use rescue::*; + +mod arithmetic; +mod buffer; +mod constant; +mod curve_add; +mod curve_dbl; +mod curve_endo; +mod deterministic_gate; +mod gate; +mod limb_sum; +mod public_input; +mod rescue; diff --git a/src/gates2/public_input.rs b/src/gates2/public_input.rs new file mode 100644 index 0000000..7345b0d --- /dev/null +++ b/src/gates2/public_input.rs @@ -0,0 +1,47 @@ +use crate::{CircuitConfig, ConstraintPolynomial, Field, Gate2, GateRef, WitnessGenerator2}; + +/// A gate for receiving public inputs. These gates will be placed at static indices and the wire +/// polynomials will always be opened at those indices. +/// +/// Each `PublicInputGate` can receive `max(num_wires, 2 * num_routed_wires)` public inputs. If all +/// wires are routed, each `PublicInputGate` simply receives `num_wires` public inputs, but if not, +/// it gets a bit more complex. We place a `BufferGate` after each `PublicInputGate`, then "copy" +/// any non-routed public inputs from the `PublicInputGate` to the routed wires of the following +/// `BufferGate`. These routed `BufferGate` wires can then be used to route public inputs elsewhere. +pub struct PublicInputGate2; + +impl Gate2 for PublicInputGate2 { + fn id(&self) -> String { + "PublicInputGate".into() + } + + fn constraints(&self, config: CircuitConfig) -> Vec> { + let routed_pis = config.num_routed_wires; + let non_routed_pis = config.advice_wires().min(routed_pis); + + // For each non-routed PI, we "copy" that PI to the following gate, which should be a + // `BufferGate` just for receiving these these PIs and making them routable. + (0..non_routed_pis).map(|i| { + let non_routed_pi_wire = ConstraintPolynomial::local_wire_value(routed_pis + i); + let routed_receiving_wire = ConstraintPolynomial::next_wire_value(i); + non_routed_pi_wire - routed_receiving_wire + }).collect() + } + + fn generators( + &self, + _config: CircuitConfig, + _gate_index: usize, + _local_constants: Vec, + _next_constants: Vec, + ) -> Vec>> { + // CircuitBuilder handles copying public input values around. + Vec::new() + } +} + +impl PublicInputGate2 { + pub fn get_ref() -> GateRef { + GateRef::new(PublicInputGate2) + } +} diff --git a/src/gates2/rescue.rs b/src/gates2/rescue.rs new file mode 100644 index 0000000..c54e83a --- /dev/null +++ b/src/gates2/rescue.rs @@ -0,0 +1,177 @@ +use crate::{apply_mds, apply_mds_constraint_polys, CircuitConfig, ConstraintPolynomial, Field, Gate2, GateRef, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2}; + +/// Implements a round of the Rescue permutation, modified with a different key schedule to reduce +/// the number of constants involved. +#[derive(Copy, Clone)] +pub struct ModifiedRescueGate { + width: usize, + alpha: usize, +} + +impl ModifiedRescueGate { + pub fn get_ref(width: usize, alpha: usize) -> GateRef { + GateRef::new(ModifiedRescueGate { width, alpha }) + } + + /// Returns the index of the `i`th accumulator wire. These act as both input and output wires. + pub fn wire_acc(&self, i: usize) -> usize { + debug_assert!(i < self.width); + i + } + + /// Returns the index of the `i`th (purported) root wire. + pub fn wire_root(&self, i: usize) -> usize { + debug_assert!(i < self.width); + self.width + i + } +} + +impl Gate2 for ModifiedRescueGate { + fn id(&self) -> String { + format!("ModifiedRescueGate[width={}, alpha={}]", self.width, self.alpha) + } + + fn constraints(&self, _config: CircuitConfig) -> Vec> { + let w = self.width; + + // Load the input layer variables. + let layer_0 = (0..w) + .map(|i| ConstraintPolynomial::local_wire_value(self.wire_acc(i))) + .collect::>(); + + // Load the (purported) alpha'th root layer variables. + let layer_1 = (0..w) + .map(|i| ConstraintPolynomial::local_wire_value(self.wire_root(i))) + .collect::>(); + + // Compute the input layer from the alpha'th root layer, so that we can verify the + // (purported) roots. + let computed_inputs = layer_1.iter() + .map(|x| x.exp(self.alpha)) + .collect::>(); + + // Apply the MDS matrix. + let layer_2 = apply_mds_constraint_polys(layer_1); + + // Add a constant to the first element. + let mut layer_3 = layer_2; + layer_3[0] = &layer_3[0] + ConstraintPolynomial::local_constant(0); + + // Raise to the alpha'th power. + let layer_4 = layer_3.iter() + .map(|x| x.exp(self.alpha)) + .collect::>(); + + // Apply the MDS matrix. + let layer_5 = apply_mds_constraint_polys(layer_4); + + // Add a constant to the first element. + let mut layer_6 = layer_5; + layer_6[0] = &layer_6[0] + ConstraintPolynomial::local_constant(1); + + let mut constraints = Vec::new(); + for i in 0..w { + // Check that the computed input matches the actual input. + constraints.push(&computed_inputs[i] - &layer_0[i]); + + // Check that the computed output matches the actual output. + let actual_out = ConstraintPolynomial::next_wire_value(self.wire_acc(i)); + constraints.push(&layer_6[i] - actual_out); + } + constraints + } + + fn generators( + &self, + _config: CircuitConfig, + gate_index: usize, + local_constants: Vec, + _next_constants: Vec, + ) -> Vec>> { + let gen = ModifiedRescueGenerator:: { + gate: *self, + gate_index, + constants: local_constants.clone(), + }; + vec![Box::new(gen)] + } +} + +struct ModifiedRescueGenerator { + gate: ModifiedRescueGate, + gate_index: usize, + constants: Vec, +} + +impl SimpleGenerator for ModifiedRescueGenerator { + fn dependencies(&self) -> Vec> { + (0..self.gate.width) + .map(|i| Target2::Wire(Wire { gate: self.gate_index, input: self.gate.wire_acc(i) })) + .collect() + } + + fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + let w = self.gate.width; + + // Load inputs. + let layer_0 = (0..w) + .map(|i| witness.get_wire( + Wire { gate: self.gate_index, input: self.gate.wire_acc(i) })) + .collect::>(); + + // Take alpha'th roots. + let layer_1 = layer_0.iter() + .map(|x| x.kth_root_usize(self.gate.alpha)) + .collect::>(); + let layer_roots = layer_1.clone(); + + // Apply MDS matrix. + let layer_2 = apply_mds(layer_1); + + // Add a constant to the first element. + let mut layer_3 = layer_2; + layer_3[0] = layer_3[0] + self.constants[0]; + + // Raise to the alpha'th power. + let layer_4 = layer_3.iter() + .map(|x| x.exp_usize(self.gate.alpha)) + .collect::>(); + + // Apply MDS matrix. + let layer_5 = apply_mds(layer_4); + + // Add a constant to the first element. + let mut layer_6 = layer_5; + layer_6[0] = layer_6[0] + self.constants[1]; + + let mut result = PartialWitness2::new(); + for i in 0..w { + // Set the i'th root wire. + result.set_wire( + Wire { gate: self.gate_index, input: self.gate.wire_root(i) }, + layer_roots[i]); + // Set the i'th output wire. + result.set_wire( + Wire { gate: self.gate_index + 1, input: self.gate.wire_acc(i) }, + layer_6[i]); + } + result + } +} + +#[cfg(test)] +mod tests { + use crate::{CircuitConfig, ModifiedRescueGate, TweedledumBase}; + + #[test] + fn rescue_gate_degree() { + let config = CircuitConfig { + num_wires: 10, + num_routed_wires: 10, + security_bits: 128, + }; + + let gate = ModifiedRescueGate::get_ref::(4, 5); + assert_eq!(gate.0.degree(config), 5); + } +} diff --git a/src/lib.rs b/src/lib.rs index 42cb6e4..b198ae2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,6 @@ // Unfortunately it makes rustc complain, so we include #![allow(incomplete_features)] - pub use bigint::*; pub use circuit_bigint::*; pub use circuit_builder::*; @@ -23,11 +22,14 @@ pub use conversions::*; pub use curve::*; pub use fft::*; pub use field::*; +pub use gadgets::*; pub use gates::*; +pub use gates2::*; pub use hash_to_curve::*; pub use mds::*; pub use partition::*; pub use plonk::*; +pub use plonk2::*; pub use plonk_proof::*; pub use plonk_recursion::*; pub use poly_commit::*; @@ -49,12 +51,15 @@ mod conversions; mod curve; mod fft; mod field; +mod gadgets; mod gates; +mod gates2; pub mod halo; mod hash_to_curve; mod mds; mod partition; mod plonk; +mod plonk2; pub mod plonk_challenger; mod plonk_proof; mod plonk_recursion; diff --git a/src/mds.rs b/src/mds.rs index a8fc6eb..cd178c6 100644 --- a/src/mds.rs +++ b/src/mds.rs @@ -1,4 +1,4 @@ -use crate::Field; +use crate::{Field, ConstraintPolynomial}; use std::any::TypeId; use once_cell::sync::Lazy; use std::sync::Mutex; @@ -28,6 +28,12 @@ pub struct MdsMatrix { } impl MdsMatrix { + /// Returns the width and height of this matrix. + pub fn size(&self) -> usize { + self.unparameterized.rows.len() + } + + /// Returns the entry at row `r` and column `c`. pub fn get(&self, r: usize, c: usize) -> F { F::from_canonical_u64_vec(self.unparameterized.rows[r][c].clone()) } @@ -39,17 +45,30 @@ struct UnparameterizedMdsMatrix { rows: Vec>>, } -/// Apply an MDS matrix to the given state vector. -pub(crate) fn apply_mds(inputs: Vec) -> Vec { - let n = inputs.len(); - let mut result = vec![F::ZERO; n]; +/// Apply an MDS matrix to the given vector of field elements. +pub(crate) fn apply_mds(vec: Vec) -> Vec { + let n = vec.len(); let mds = mds_matrix::(n); - for r in 0..n { - for c in 0..n { - result[r] = result[r] + mds.get(r, c) * inputs[c]; - } - } - result + + (0..n) + .map(|r| (0..n) + .map(|c| mds.get(r, c) * vec[c]) + .fold(F::ZERO, |acc, x| acc + x)) + .collect() +} + +/// Applies an MDS matrix to the given vector of constraint polynomials. +pub(crate) fn apply_mds_constraint_polys( + vec: Vec>, +) -> Vec> { + let n = vec.len(); + let mds = mds_matrix::(n); + + (0..n) + .map(|r| (0..n) + .map(|c| &vec[c] * mds.get(r, c)) + .sum()) + .collect() } /// Returns entry `(r, c)` of an `n` by `n` MDS matrix. diff --git a/src/plonk2/circuit_builder.rs b/src/plonk2/circuit_builder.rs new file mode 100644 index 0000000..e33f12c --- /dev/null +++ b/src/plonk2/circuit_builder.rs @@ -0,0 +1,88 @@ +use std::collections::HashSet; + +use crate::{CircuitConfig, ConstantGate2, CopyGenerator, Field, GateInstance, GateRef, Target2, Wire, WitnessGenerator2}; + +pub struct CircuitBuilder2 { + config: CircuitConfig, + gates: HashSet>, + gate_instances: Vec>, + generators: Vec>>, +} + +impl CircuitBuilder2 { + pub fn new(config: CircuitConfig) -> Self { + CircuitBuilder2 { + config, + gates: HashSet::new(), + gate_instances: Vec::new(), + generators: Vec::new(), + } + } + + /// Adds a gate to the circuit, and returns its index. + pub fn add_gate(&mut self, gate_type: GateRef, constants: Vec) -> usize { + // If we haven't seen a gate of this type before, check that it's compatible with our + // circuit configuration, then register it. + if !self.gates.contains(&gate_type) { + self.check_gate_compatibility(&gate_type); + self.gates.insert(gate_type.clone()); + } + + let index = self.gate_instances.len(); + self.gate_instances.push(GateInstance { gate_type, constants }); + index + } + + fn check_gate_compatibility(&self, gate: &GateRef) { + assert!(gate.0.min_wires(self.config) <= self.config.num_wires); + } + + /// Shorthand for `generate_copy` and `assert_equal`. + /// Both elements must be routable, otherwise this method will panic. + pub fn route(&mut self, src: Target2, dst: Target2) { + self.generate_copy(src, dst); + self.assert_equal(src, dst); + } + + /// Adds a generator which will copy `src` to `dst`. + pub fn generate_copy(&mut self, src: Target2, dst: Target2) { + self.add_generator(CopyGenerator { src, dst }); + } + + /// Uses Plonk's permutation argument to require that two elements be equal. + /// Both elements must be routable, otherwise this method will panic. + pub fn assert_equal(&mut self, x: Target2, y: Target2) { + assert!(x.is_routable()); + assert!(y.is_routable()); + } + + pub fn add_generator>(&mut self, generator: G) { + self.generators.push(Box::new(generator)); + } + + /// Returns a routable target with a value of 0. + pub fn zero(&mut self) -> Target2 { + self.constant(F::ZERO) + } + + /// Returns a routable target with a value of 1. + pub fn one(&mut self) -> Target2 { + self.constant(F::ONE) + } + + /// Returns a routable target with a value of 2. + pub fn two(&mut self) -> Target2 { + self.constant(F::TWO) + } + + /// Returns a routable target with a value of `ORDER - 1`. + pub fn neg_one(&mut self) -> Target2 { + self.constant(F::NEG_ONE) + } + + /// Returns a routable target with the given constant value. + pub fn constant(&mut self, c: F) -> Target2 { + let gate = self.add_gate(ConstantGate2::get_ref(), vec![c]); + Target2::Wire(Wire { gate, input: ConstantGate2::WIRE_OUTPUT }) + } +} diff --git a/src/plonk2/circuit_data.rs b/src/plonk2/circuit_data.rs new file mode 100644 index 0000000..03591a4 --- /dev/null +++ b/src/plonk2/circuit_data.rs @@ -0,0 +1,88 @@ +use crate::{FftPrecomputation, MsmPrecomputation, AffinePoint, Curve, Proof2}; +use crate::plonk2::prover::prove2; +use crate::plonk2::verifier::verify2; + +#[derive(Copy, Clone)] +pub struct CircuitConfig { + pub num_wires: usize, + pub num_routed_wires: usize, + pub security_bits: usize, +} + +impl CircuitConfig { + pub fn advice_wires(&self) -> usize { + self.num_wires - self.num_routed_wires + } +} + +/// Circuit data required by the prover or the verifier. +pub struct CircuitData { + prover_only: ProverOnlyCircuitData, + verifier_only: VerifierOnlyCircuitData, + common: CommonCircuitData, +} + +impl CircuitData { + pub fn prove2(&self) -> Proof2 { + prove2(&self.prover_only, &self.common) + } + + pub fn verify2(&self) { + verify2(&self.verifier_only, &self.common) + } +} + +/// Circuit data required by the prover. +pub struct ProverCircuitData { + prover_only: ProverOnlyCircuitData, + common: CommonCircuitData, +} + +impl ProverCircuitData { + pub fn prove2(&self) -> Proof2 { + prove2(&self.prover_only, &self.common) + } +} + +/// Circuit data required by the prover. +pub struct VerifierCircuitData { + verifier_only: VerifierOnlyCircuitData, + common: CommonCircuitData, +} + +impl VerifierCircuitData { + pub fn verify2(&self) { + verify2(&self.verifier_only, &self.common) + } +} + +/// Circuit data required by the prover, but not the verifier. +pub(crate) struct ProverOnlyCircuitData { + /// A precomputation used for FFTs of degree 8n, where n is the number of gates. + pub fft_precomputation_8n: FftPrecomputation, + + /// A precomputation used for MSMs involving `generators`. + pub pedersen_g_msm_precomputation: MsmPrecomputation, +} + +/// Circuit data required by the verifier, but not the prover. +pub(crate) struct VerifierOnlyCircuitData {} + +/// Circuit data required by both the prover and the verifier. +pub(crate) struct CommonCircuitData { + pub config: CircuitConfig, + + pub degree: usize, + + /// A commitment to each constant polynomial. + pub c_constants: Vec>, + + /// A commitment to each permutation polynomial. + pub c_s_sigmas: Vec>, + + /// A precomputation used for MSMs involving `generators`. + pub pedersen_g_msm_precomputation: MsmPrecomputation, + + /// A precomputation used for FFTs of degree n, where n is the number of gates. + pub fft_precomputation_n: FftPrecomputation, +} diff --git a/src/plonk2/constraint_polynomial.rs b/src/plonk2/constraint_polynomial.rs new file mode 100644 index 0000000..ecb9272 --- /dev/null +++ b/src/plonk2/constraint_polynomial.rs @@ -0,0 +1,419 @@ +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::iter::{Product, Sum}; +use std::ops::{Add, Mul, Neg, Sub}; +use std::ptr; +use std::rc::Rc; + +use crate::{Field, Wire}; + +pub(crate) struct EvaluationVars<'a, F: Field> { + pub(crate) local_constants: &'a [F], + pub(crate) next_constants: &'a [F], + pub(crate) local_wire_values: &'a [F], + pub(crate) next_wire_values: &'a [F], +} + +/// A polynomial over all the variables that are subject to constraints (local constants, next +/// constants, local wire values, and next wire values). This representation does not require any +/// particular form; it permits arbitrary forms such as `(x + 1)^3 + y z`. +// Implementation note: This is a wrapper because we want to hide complexity behind +// `ConstraintPolynomialInner` and `ConstraintPolynomialRef`. In particular, the caller shouldn't +// need to know that we use reference counting internally, and shouldn't have to deal with wrapper +// types related to reference counting. +#[derive(Clone)] +pub struct ConstraintPolynomial(ConstraintPolynomialRef); + +impl ConstraintPolynomial { + pub fn constant(c: F) -> Self { + Self::from_inner(ConstraintPolynomialInner::Constant(c)) + } + + pub fn constant_usize(c: usize) -> Self { + Self::constant(F::from_canonical_usize(c)) + } + + pub fn zero() -> Self { + Self::constant(F::ZERO) + } + + pub fn one() -> Self { + Self::constant(F::ONE) + } + + pub fn local_constant(index: usize) -> Self { + Self::from_inner(ConstraintPolynomialInner::LocalConstant(index)) + } + + pub fn next_constant(index: usize) -> Self { + Self::from_inner(ConstraintPolynomialInner::NextConstant(index)) + } + + pub fn local_wire_value(index: usize) -> Self { + Self::from_inner(ConstraintPolynomialInner::LocalWireValue(index)) + } + + pub fn next_wire_value(index: usize) -> Self { + Self::from_inner(ConstraintPolynomialInner::NextWireValue(index)) + } + + // TODO: Have these take references? + pub fn add(&self, rhs: &Self) -> Self { + // TODO: Special case for either operand being 0. + Self::from_inner(ConstraintPolynomialInner::Sum { + lhs: self.0.clone(), + rhs: rhs.0.clone(), + }) + } + + pub fn sub(&self, rhs: &Self) -> Self { + // TODO: Special case for either operand being 0. + // TODO: Faster to have a dedicated ConstraintPolynomialInner::Difference? + // TODO: `self + -rhs`? + self.add(&rhs.neg()) + } + + pub fn double(&self) -> Self { + self.clone().add(self) + } + + pub fn triple(&self) -> Self { + self * 3 + } + + pub fn mul(&self, rhs: &Self) -> Self { + // TODO: Special case for either operand being 1. + Self::from_inner(ConstraintPolynomialInner::Product { + lhs: self.0.clone(), + rhs: rhs.0.clone(), + }) + } + + pub fn exp(&self, exponent: usize) -> Self { + Self::from_inner(ConstraintPolynomialInner::Exponentiation { + base: self.0.clone(), + exponent, + }) + } + + pub fn square(&self) -> Self { + self * self + } + + pub(crate) fn degree(&self) -> usize { + (self.0).0.degree() + } + + /// Returns the set of wires that this constraint would depend on if it were applied at a + /// certain gate index. + pub(crate) fn dependencies(&self, gate: usize) -> Vec { + let mut deps = HashSet::new(); + self.0.0.add_dependencies(gate, &mut deps); + deps.into_iter().collect() + } + + /// Find the largest input index among the wires this constraint depends on. + pub(crate) fn max_wire_input_index(&self) -> Option { + self.dependencies(0) + .into_iter() + .map(|wire| wire.input) + .max() + } + + pub(crate) fn max_constant_index(&self) -> Option { + let mut indices = HashSet::new(); + self.0.0.add_constant_indices(&mut indices); + indices.into_iter().max() + } + + pub(crate) fn evaluate(&self, vars: EvaluationVars) -> F { + let results = Self::evaluate_all(&[self.clone()], vars); + assert_eq!(results.len(), 1); + results[0] + } + + /// Evaluate multiple constraint polynomials simultaneously. This can be more efficient than + /// evaluating them sequentially, since shared intermediate results will only be computed once. + pub(crate) fn evaluate_all( + polynomials: &[ConstraintPolynomial], + vars: EvaluationVars, + ) -> Vec { + let mut mem = HashMap::new(); + polynomials.iter() + .map(|p| p.0.evaluate_memoized(&vars, &mut mem)) + .collect() + } + + fn from_inner(inner: ConstraintPolynomialInner) -> Self { + Self(ConstraintPolynomialRef::new(inner)) + } +} + +impl Neg for ConstraintPolynomial { + type Output = Self; + + fn neg(self) -> Self { + // TODO: Faster to have a dedicated ConstraintPolynomialInner::Negation? + self * ConstraintPolynomial::constant(F::NEG_ONE) + } +} + +impl Neg for &ConstraintPolynomial { + type Output = ConstraintPolynomial; + + fn neg(self) -> ConstraintPolynomial { + self.clone().neg() + } +} + +/// Generates the following variants of a binary operation: +/// - `Self . Self` +/// - `&Self . Self` +/// - `Self . &Self` +/// - `&Self . &Self` +/// - `Self . F` +/// - `&Self . F` +/// - `Self . usize` +/// - `&Self . usize` +/// where `Self` is `ConstraintPolynomial`. +/// +/// Takes the following arguments: +/// - `$trait`: the name of the binary operation trait to implement +/// - `$method`: the name of the method in the trait. It is assumed that `ConstraintPolynomial` +/// contains a method with the same name, implementing the `Self . Self` variant. +macro_rules! binop_variants { + ($trait:ident, $method:ident) => { + impl $trait for ConstraintPolynomial { + type Output = Self; + + fn $method(self, rhs: Self) -> Self { + ConstraintPolynomial::$method(&self, &rhs) + } + } + + impl $trait<&Self> for ConstraintPolynomial { + type Output = Self; + + fn $method(self, rhs: &Self) -> Self { + ConstraintPolynomial::$method(&self, rhs) + } + } + + impl $trait> for &ConstraintPolynomial { + type Output = ConstraintPolynomial; + + fn $method(self, rhs: ConstraintPolynomial) -> Self::Output { + ConstraintPolynomial::$method(self, &rhs) + } + } + + impl $trait for &ConstraintPolynomial { + type Output = ConstraintPolynomial; + + fn $method(self, rhs: Self) -> Self::Output { + ConstraintPolynomial::$method(self, rhs) + } + } + + impl $trait for ConstraintPolynomial { + type Output = Self; + + fn $method(self, rhs: F) -> Self { + ConstraintPolynomial::$method(&self, &ConstraintPolynomial::constant(rhs)) + } + } + + impl $trait for &ConstraintPolynomial { + type Output = ConstraintPolynomial; + + fn $method(self, rhs: F) -> Self::Output { + ConstraintPolynomial::$method(self, &ConstraintPolynomial::constant(rhs)) + } + } + + impl $trait for ConstraintPolynomial { + type Output = Self; + + fn $method(self, rhs: usize) -> Self { + ConstraintPolynomial::$method(&self, &ConstraintPolynomial::constant_usize(rhs)) + } + } + + impl $trait for &ConstraintPolynomial { + type Output = ConstraintPolynomial; + + fn $method(self, rhs: usize) -> Self::Output { + ConstraintPolynomial::$method(self, &ConstraintPolynomial::constant_usize(rhs)) + } + } + }; +} + +binop_variants!(Add, add); +binop_variants!(Sub, sub); +binop_variants!(Mul, mul); + +impl Sum for ConstraintPolynomial { + fn sum>(iter: I) -> Self { + iter.fold( + ConstraintPolynomial::zero(), + |sum, x| sum + x) + } +} + +impl Product for ConstraintPolynomial { + fn product>(iter: I) -> Self { + iter.fold( + ConstraintPolynomial::one(), + |product, x| product * x) + } +} + +enum ConstraintPolynomialInner { + Constant(F), + + LocalConstant(usize), + NextConstant(usize), + LocalWireValue(usize), + NextWireValue(usize), + + Sum { + lhs: ConstraintPolynomialRef, + rhs: ConstraintPolynomialRef, + }, + Product { + lhs: ConstraintPolynomialRef, + rhs: ConstraintPolynomialRef, + }, + Exponentiation { + base: ConstraintPolynomialRef, + exponent: usize, + }, +} + +impl ConstraintPolynomialInner { + fn add_dependencies(&self, gate: usize, deps: &mut HashSet) { + match self { + ConstraintPolynomialInner::Constant(_) => (), + ConstraintPolynomialInner::LocalConstant(_) => (), + ConstraintPolynomialInner::NextConstant(_) => (), + ConstraintPolynomialInner::LocalWireValue(i) => + { deps.insert(Wire { gate, input: *i }); }, + ConstraintPolynomialInner::NextWireValue(i) => + { deps.insert(Wire { gate: gate + 1, input: *i }); } + ConstraintPolynomialInner::Sum { lhs, rhs } => { + lhs.0.add_dependencies(gate, deps); + rhs.0.add_dependencies(gate, deps); + }, + ConstraintPolynomialInner::Product { lhs, rhs } => { + lhs.0.add_dependencies(gate, deps); + rhs.0.add_dependencies(gate, deps); + }, + ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => { + base.0.add_dependencies(gate, deps); + }, + } + } + + fn add_constant_indices(&self, indices: &mut HashSet) { + match self { + ConstraintPolynomialInner::Constant(_) => (), + ConstraintPolynomialInner::LocalConstant(i) => { indices.insert(*i); }, + ConstraintPolynomialInner::NextConstant(i) => { indices.insert(*i); }, + ConstraintPolynomialInner::LocalWireValue(_) => (), + ConstraintPolynomialInner::NextWireValue(_) => (), + ConstraintPolynomialInner::Sum { lhs, rhs } => { + lhs.0.add_constant_indices(indices); + rhs.0.add_constant_indices(indices); + }, + ConstraintPolynomialInner::Product { lhs, rhs } => { + lhs.0.add_constant_indices(indices); + rhs.0.add_constant_indices(indices); + }, + ConstraintPolynomialInner::Exponentiation { base, exponent: _ } => { + base.0.add_constant_indices(indices); + }, + } + } + + fn evaluate( + &self, + vars: &EvaluationVars, + mem: &mut HashMap, F>, + ) -> F { + match self { + ConstraintPolynomialInner::Constant(c) => *c, + ConstraintPolynomialInner::LocalConstant(i) => vars.local_constants[*i], + ConstraintPolynomialInner::NextConstant(i) => vars.next_constants[*i], + ConstraintPolynomialInner::LocalWireValue(i) => vars.local_wire_values[*i], + ConstraintPolynomialInner::NextWireValue(i) => vars.next_wire_values[*i], + ConstraintPolynomialInner::Sum { lhs, rhs } => { + let lhs = lhs.evaluate_memoized(vars, mem); + let rhs = rhs.evaluate_memoized(vars, mem); + lhs + rhs + }, + ConstraintPolynomialInner::Product { lhs, rhs } => { + let lhs = lhs.evaluate_memoized(vars, mem); + let rhs = rhs.evaluate_memoized(vars, mem); + lhs * rhs + }, + ConstraintPolynomialInner::Exponentiation { base, exponent } => { + let base = base.evaluate_memoized(vars, mem); + base.exp_usize(*exponent) + }, + } + } + + fn degree(&self) -> usize { + match self { + ConstraintPolynomialInner::Constant(_) => 0, + ConstraintPolynomialInner::LocalConstant(_) => 1, + ConstraintPolynomialInner::NextConstant(_) => 1, + ConstraintPolynomialInner::LocalWireValue(_) => 1, + ConstraintPolynomialInner::NextWireValue(_) => 1, + ConstraintPolynomialInner::Sum { lhs, rhs } => lhs.0.degree().max(rhs.0.degree()), + ConstraintPolynomialInner::Product { lhs, rhs } => lhs.0.degree() + rhs.0.degree(), + ConstraintPolynomialInner::Exponentiation { base, exponent } => base.0.degree() * exponent, + } + } +} + +/// Wraps `Rc`, and implements `Hash` and `Eq` based on references rather +/// than content. This is useful when we want to use constraint polynomials as `HashMap` keys, but +/// we want address-based hashing for performance reasons. +#[derive(Clone)] +struct ConstraintPolynomialRef(Rc>); + +impl ConstraintPolynomialRef { + fn new(inner: ConstraintPolynomialInner) -> Self { + Self(Rc::new(inner)) + } + + fn evaluate_memoized( + &self, + vars: &EvaluationVars, + mem: &mut HashMap, + ) -> F { + if let Some(&result) = mem.get(self) { + result + } else { + let result = self.0.evaluate(vars, mem); + mem.insert(self.clone(), result); + result + } + } +} + +impl PartialEq for ConstraintPolynomialRef { + fn eq(&self, other: &Self) -> bool { + ptr::eq(&*self.0, &*other.0) + } +} + +impl Eq for ConstraintPolynomialRef {} + +impl Hash for ConstraintPolynomialRef { + fn hash(&self, state: &mut H) { + ptr::hash(&*self.0, state); + } +} diff --git a/src/plonk2/gate.rs b/src/plonk2/gate.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/plonk2/generator.rs b/src/plonk2/generator.rs new file mode 100644 index 0000000..a67d9cb --- /dev/null +++ b/src/plonk2/generator.rs @@ -0,0 +1,52 @@ +use crate::{Field, PartialWitness2, Target2}; + +/// A generator participates in the generation of the witness. +pub trait WitnessGenerator2: 'static { + /// Targets to be "watched" by this generator. Whenever a target in the watch list is populated, + /// the generator will be queued to run. + fn watch_list(&self) -> Vec>; + + /// Run this generator, returning a `PartialWitness` containing any new witness elements, and a + /// flag indicating whether the generator is finished. If the flag is true, the generator will + /// never be run again, otherwise it will be queued for another run next time a target in its + /// watch list is populated. + fn run(&mut self, witness: &PartialWitness2) -> (PartialWitness2, bool); +} + +/// A generator which runs once after a list of dependencies is present in the witness. +pub trait SimpleGenerator: 'static { + fn dependencies(&self) -> Vec>; + + fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2; +} + +impl> WitnessGenerator2 for SG { + fn watch_list(&self) -> Vec> { + self.dependencies() + } + + fn run(&mut self, witness: &PartialWitness2) -> (PartialWitness2, bool) { + if witness.contains_all(&self.dependencies()) { + (self.run_once(witness), true) + } else { + (PartialWitness2::new(), false) + } + } +} + +/// A generator which copies one wire to another. +pub(crate) struct CopyGenerator { + pub(crate) src: Target2, + pub(crate) dst: Target2, +} + +impl SimpleGenerator for CopyGenerator { + fn dependencies(&self) -> Vec> { + vec![self.src] + } + + fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + let value = witness.get_target(self.src); + PartialWitness2::singleton(self.dst, value) + } +} diff --git a/src/plonk2/mod.rs b/src/plonk2/mod.rs new file mode 100644 index 0000000..b8f8422 --- /dev/null +++ b/src/plonk2/mod.rs @@ -0,0 +1,20 @@ +pub use circuit_builder::*; +pub use circuit_data::*; +pub use constraint_polynomial::*; +pub use gate::*; +pub use generator::*; +pub use proof::*; +pub use target::*; +pub use witness::*; + +mod circuit_builder; +mod circuit_data; +mod constraint_polynomial; +mod gate; +mod generator; +mod partitions; +mod proof; +mod prover; +mod target; +mod verifier; +mod witness; diff --git a/src/plonk2/partitions.rs b/src/plonk2/partitions.rs new file mode 100644 index 0000000..f73905c --- /dev/null +++ b/src/plonk2/partitions.rs @@ -0,0 +1,49 @@ +use std::collections::HashMap; + +use crate::{Field, Target2}; + +#[derive(Debug)] +pub(crate) struct Partitions2 { + partitions: Vec>>, + indices: HashMap, usize>, +} + +impl Partitions2 { + pub fn new() -> Self { + Self { + partitions: Vec::new(), + indices: HashMap::new(), + } + } + + /// Adds the targets as new singleton partitions if they are not already present, then merges + /// their partitions if the targets are not already in the same partition. + pub fn merge(&mut self, a: Target2, b: Target2) { + let a_index = self.get_index(a); + let b_index = self.get_index(b); + + if a_index != b_index { + // Merge a's partition into b's partition, leaving a's partition empty. + // We have to clone because Rust's borrow checker doesn't know that + // self.partitions[b_index] and self.partitions[b_index] are disjoint. + let mut a_partition = self.partitions[a_index].clone(); + let b_partition = &mut self.partitions[b_index]; + for a_sibling in &a_partition { + *self.indices.get_mut(a_sibling).unwrap() = b_index; + } + b_partition.append(&mut a_partition); + } + } + + /// Gets the partition index of a given target. If the target is not present, adds it as a new + /// singleton partition and returns the new partition's index. + fn get_index(&mut self, target: Target2) -> usize { + if let Some(&index) = self.indices.get(&target) { + index + } else { + let index = self.partitions.len(); + self.partitions.push(vec![target]); + index + } + } +} diff --git a/src/plonk2/proof.rs b/src/plonk2/proof.rs new file mode 100644 index 0000000..e3e451a --- /dev/null +++ b/src/plonk2/proof.rs @@ -0,0 +1,3 @@ +pub struct Proof2 {} + +pub struct ProofTarget2 {} diff --git a/src/plonk2/prover.rs b/src/plonk2/prover.rs new file mode 100644 index 0000000..4c12c46 --- /dev/null +++ b/src/plonk2/prover.rs @@ -0,0 +1,8 @@ +use crate::{Curve, ProverOnlyCircuitData, CommonCircuitData, Proof2}; + +pub(crate) fn prove2( + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, +) -> Proof2 { + todo!() +} diff --git a/src/plonk2/target.rs b/src/plonk2/target.rs new file mode 100644 index 0000000..c519370 --- /dev/null +++ b/src/plonk2/target.rs @@ -0,0 +1,29 @@ +use std::convert::Infallible; +use std::marker::PhantomData; + +use crate::{Field, Wire}; + +/// A location in the witness. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub enum Target2 { + Wire(Wire), + PublicInput { index: usize }, + VirtualAdviceTarget { index: usize }, + // Trick taken from https://github.com/rust-lang/rust/issues/32739#issuecomment-627765543. + _Field(Infallible, PhantomData), +} + +impl Target2 { + pub fn wire(gate: usize, input: usize) -> Self { + Self::Wire(Wire { gate, input }) + } + + pub fn is_routable(&self) -> bool { + match self { + Target2::Wire(wire) => wire.is_routable(), + Target2::PublicInput { .. } => true, + Target2::VirtualAdviceTarget { .. } => false, + Target2::_Field(_, _) => unreachable!(), + } + } +} diff --git a/src/plonk2/verifier.rs b/src/plonk2/verifier.rs new file mode 100644 index 0000000..9b31fda --- /dev/null +++ b/src/plonk2/verifier.rs @@ -0,0 +1,8 @@ +use crate::{Curve, VerifierOnlyCircuitData, CommonCircuitData}; + +pub(crate) fn verify2( + verifier_data: &VerifierOnlyCircuitData, + common_data: &CommonCircuitData, +) { + todo!() +} diff --git a/src/plonk2/witness.rs b/src/plonk2/witness.rs new file mode 100644 index 0000000..41196a9 --- /dev/null +++ b/src/plonk2/witness.rs @@ -0,0 +1,54 @@ +use std::collections::HashMap; + +use crate::{Field, Target2, Wire}; + +#[derive(Debug)] +pub struct PartialWitness2 { + target_values: HashMap, F>, +} + +impl PartialWitness2 { + pub fn new() -> Self { + PartialWitness2 { + target_values: HashMap::new(), + } + } + + pub fn singleton(target: Target2, value: F) -> Self { + let mut witness = PartialWitness2::new(); + witness.set_target(target, value); + witness + } + + pub fn is_empty(&self) -> bool { + self.target_values.is_empty() + } + + pub fn get_target(&self, target: Target2) -> F { + self.target_values[&target] + } + + pub fn try_get_target(&self, target: Target2) -> Option { + self.target_values.get(&target).cloned() + } + + pub fn get_wire(&self, wire: Wire) -> F { + self.get_target(Target2::Wire(wire)) + } + + pub fn contains(&self, target: Target2) -> bool { + self.target_values.contains_key(&target) + } + + pub fn contains_all(&self, targets: &[Target2]) -> bool { + targets.iter().all(|&t| self.contains(t)) + } + + pub fn set_target(&mut self, target: Target2, value: F) { + self.target_values.insert(target, value); + } + + pub fn set_wire(&mut self, wire: Wire, value: F) { + self.set_target(Target2::Wire(wire), value) + } +} diff --git a/src/rescue.rs b/src/rescue.rs index 6cd74fd..2aa0bb9 100644 --- a/src/rescue.rs +++ b/src/rescue.rs @@ -99,6 +99,7 @@ pub(crate) fn generate_rescue_constants( security_bits: usize, ) -> Vec<(Vec, Vec)> { // TODO: This should use deterministic randomness. + // TODO: Reject subgroup elements. // FIX: Use ChaCha CSPRNG with a seed. This is somewhat similar to official implementation // at https://github.com/KULeuven-COSIC/Marvellous/blob/master/instance_generator.sage where they // use SHAKE256 with a seed to generate randomness.