diff --git a/src/utils/field.rs b/src/utils/field.rs index 82234da..184b2c6 100644 --- a/src/utils/field.rs +++ b/src/utils/field.rs @@ -36,21 +36,46 @@ impl FieldElement { BASE64_STANDARD.encode(&self.field_element) } - pub fn pow(&self, mut exponent: u128) -> FieldElement { + pub fn pow(mut self, mut exponent: u128) -> FieldElement { + let mut result: FieldElement = + FieldElement::new(polynomial_2_block(vec![0], "gcm").unwrap()); + + if exponent == 1 { + eprintln!("special case 1: {:02X?}", self.clone()); + + return self; + } + if exponent == 0 { - // Return polynomial with coefficient 1 - return FieldElement::new(vec![1]); + let result = FieldElement::new(polynomial_2_block(vec![0], "gcm").unwrap()); + + eprintln!("Returned value is: {:02X?}", result); + return result; } - let base = self.clone(); - let mut result = base.clone(); - exponent -= 1; // Subtract 1 because we already set result to base - + //eprintln!("Initial result: {:?}", result); while exponent > 0 { - result = result * base.clone(); - exponent -= 1; + //eprintln!("Current exponent: {:02X}", exponent); + if exponent & 1 == 1 { + let temp = &self * &result; + eprintln!("Mult"); + eprintln!("After mod: {:?}", temp); + + result = temp + } + let temp_square = &self * &self; + eprintln!("Square"); + + eprintln!("After squaring: {:?}", temp_square); + self = temp_square; + //eprintln!("After mod: {:?}", self); + exponent >>= 1; } + eprintln!("result in powmod before reduction: {:02X?}", result); + + eprintln!("result in powmod after reduction: {:02X?}", result); + result } diff --git a/src/utils/poly.rs b/src/utils/poly.rs index e624361..642be5c 100644 --- a/src/utils/poly.rs +++ b/src/utils/poly.rs @@ -56,19 +56,63 @@ impl Polynomial { output } - pub fn pow(&self, mut exponent: u128) -> Polynomial { - if exponent == 0 { - return Polynomial::new(vec![FieldElement::new( - polynomial_2_block(vec![0], "gcm").unwrap(), - )]); + pub fn pow(mut self, mut exponent: u128) -> Polynomial { + let mut result: Polynomial = Polynomial::new(vec![FieldElement::new( + polynomial_2_block(vec![0], "gcm").unwrap(), + )]); + + if exponent == 1 { + eprintln!("special case 1: {:02X?}", self.clone()); + + return self; } - let base = self.clone(); - let mut result = base.clone(); - exponent -= 1; + if exponent == 0 { + let result = Polynomial::new(vec![FieldElement::new( + polynomial_2_block(vec![0], "gcm").unwrap(), + )]); + + eprintln!("Returned value is: {:02X?}", result); + return result; + } + + //eprintln!("Initial result: {:?}", result); while exponent > 0 { - result = result * base.clone(); - exponent -= 1; + //eprintln!("Current exponent: {:02X}", exponent); + if exponent & 1 == 1 { + let temp = &self * &result; + eprintln!("Mult"); + eprintln!("After mod: {:?}", temp); + + result = temp + } + let temp_square = &self * &self; + eprintln!("Square"); + + eprintln!("After squaring: {:?}", temp_square); + self = temp_square; + //eprintln!("After mod: {:?}", self); + exponent >>= 1; + } + + eprintln!("result in powmod before reduction: {:02X?}", result); + + while !result.polynomial.is_empty() + && result + .polynomial + .last() + .unwrap() + .as_ref() + .iter() + .all(|&x| x == 0) + { + result.polynomial.pop(); + } + + eprintln!("result in powmod after reduction: {:02X?}", result); + + if result.is_empty() { + result = Polynomial::new(vec![FieldElement::new(vec![0; 16])]); } result @@ -237,6 +281,18 @@ impl Polynomial { self } + + fn sqrt(self) -> Self { + let mut result = vec![]; + + for (position, element) in self.polynomial.iter().enumerate() { + if position % 2 == 0 { + result.push(element.clone().pow(2u128.pow(127))); + } + } + + Polynomial::new(result) + } } impl Clone for Polynomial { @@ -1112,4 +1168,29 @@ mod tests { assert_eq!(json!(result.to_c_array()), expected); } + + #[test] + fn test_poly_poly_sqrt() { + let json1 = json!([ + "5TxUxLHO1lHE/rSFquKIAg==", + "AAAAAAAAAAAAAAAAAAAAAA==", + "0DEUJYdHlmd4X7nzzIdcCA==", + "AAAAAAAAAAAAAAAAAAAAAA==", + "PKUa1+JHTxHE8y3LbuKIIA==", + "AAAAAAAAAAAAAAAAAAAAAA==", + "Ds96KiAKKoigKoiKiiKAiA==" + ]); + let expected = json!([ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + eprintln!("Starting poly sqrt"); + + let result = element1.sqrt(); + + assert_eq!(json!(result.to_c_array()), expected); + } }