From 5dc299372a742b2ba76302574c36ca2f6d5f9ad4 Mon Sep 17 00:00:00 2001 From: alivecow Date: Fri, 15 Nov 2024 10:13:05 +0100 Subject: [PATCH] fix: Add handling of zero mulitplication for polynomials --- src/utils/field.rs | 105 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 3 deletions(-) diff --git a/src/utils/field.rs b/src/utils/field.rs index 14d0c46..758ea75 100644 --- a/src/utils/field.rs +++ b/src/utils/field.rs @@ -151,6 +151,15 @@ impl Polynomial { (Polynomial::new(quotient_coeffs), remainder) } + + fn is_zero(&self) -> bool { + for field_element in &self.polynomial { + if !field_element.is_zero() { + return false; + } + } + true + } } impl Clone for Polynomial { @@ -163,8 +172,10 @@ impl Clone for Polynomial { impl Mul for Polynomial { type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { + if self.is_zero() || rhs.is_zero() { + return Polynomial::new(vec![FieldElement::new(vec![0; 16])]); + } let mut polynomial: Vec = vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; for i in 0..self.polynomial.len() { @@ -180,6 +191,9 @@ impl Mul for Polynomial { impl Mul for &Polynomial { type Output = Polynomial; fn mul(self, rhs: Self) -> Self::Output { + if self.is_zero() || rhs.is_zero() { + return Polynomial::new(vec![FieldElement::new(vec![0])]); + } let mut polynomial: Vec = vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; for i in 0..self.polynomial.len() { @@ -296,6 +310,10 @@ impl FieldElement { //eprintln!("Inverse rhs {:?}", inverse); FieldElement::new(inverse) } + + fn is_zero(&self) -> bool { + self.field_element.iter().all(|&x| x == 0x00) + } } impl Mul for FieldElement { @@ -662,6 +680,68 @@ mod tests { ); } + #[test] + fn test_field_add_zero() { + let json1 = json!([ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ]); + let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + let sum = element2 + element1; + + assert_eq!( + sum.to_c_array(), + vec![ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ] + ); + } + + #[test] + fn test_field_add_zero_to_zero() { + let json1 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]); + let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + let sum = element2 + element1; + + assert_eq!(sum.to_c_array(), vec!["AAAAAAAAAAAAAAAAAAAAAA=="]); + } + + #[test] + fn test_field_add_short_to_long() { + let json1 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]); + let json2 = json!([ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + let sum = element2 + element1; + + assert_eq!( + sum.to_c_array(), + vec![ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ] + ); + } + #[test] fn test_field_mul_01() { let json1 = json!([ @@ -690,7 +770,26 @@ mod tests { } #[test] - fn test_field_pow_01() { + fn test_poly_mul_with_zero() { + let json1 = json!([ + "JAAAAAAAAAAAAAAAAAAAAA==", + "wAAAAAAAAAAAAAAAAAAAAA==", + "ACAAAAAAAAAAAAAAAAAAAA==" + ]); + let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + //eprintln!("{:?}", element1); + + let result = element1 * element2; + + assert_eq!(result.to_c_array(), vec!["AAAAAAAAAAAAAAAAAAAAAA=="]); + //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); + } + + #[test] + fn test_poly_pow_01() { let json1 = json!([ "JAAAAAAAAAAAAAAAAAAAAA==", "wAAAAAAAAAAAAAAAAAAAAA==", @@ -792,6 +891,6 @@ mod tests { let result = element1.pow_mod(1000, modulus); eprintln!("Result is: {:02X?}", result); - assert_eq!(result.to_c_array(), vec!["XrEhmKuat+Glt5zZWtMo6g=="]); + assert_eq!(result.to_c_array(), vec!["oNXl5P8xq2WpUTP92u25zg=="]); } }