Merge fixes for pfmath functions #14

Merged
0xalivecow merged 4 commits from dev into main 2024-11-15 11:50:31 +00:00
Showing only changes of commit 5dc299372a - Show all commits

View file

@ -151,6 +151,15 @@ impl Polynomial {
(Polynomial::new(quotient_coeffs), remainder) (Polynomial::new(quotient_coeffs), remainder)
} }
fn is_zero(&self) -> bool {
for field_element in &self.polynomial {
if !field_element.is_zero() {
return false;
}
}
true
}
} }
impl Clone for Polynomial { impl Clone for Polynomial {
@ -163,8 +172,10 @@ impl Clone for Polynomial {
impl Mul for Polynomial { impl Mul for Polynomial {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
if self.is_zero() || rhs.is_zero() {
return Polynomial::new(vec![FieldElement::new(vec![0; 16])]);
}
let mut polynomial: Vec<FieldElement> = let mut polynomial: Vec<FieldElement> =
vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1];
for i in 0..self.polynomial.len() { for i in 0..self.polynomial.len() {
@ -180,6 +191,9 @@ impl Mul for Polynomial {
impl Mul for &Polynomial { impl Mul for &Polynomial {
type Output = Polynomial; type Output = Polynomial;
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
if self.is_zero() || rhs.is_zero() {
return Polynomial::new(vec![FieldElement::new(vec![0])]);
}
let mut polynomial: Vec<FieldElement> = let mut polynomial: Vec<FieldElement> =
vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1];
for i in 0..self.polynomial.len() { for i in 0..self.polynomial.len() {
@ -296,6 +310,10 @@ impl FieldElement {
//eprintln!("Inverse rhs {:?}", inverse); //eprintln!("Inverse rhs {:?}", inverse);
FieldElement::new(inverse) FieldElement::new(inverse)
} }
fn is_zero(&self) -> bool {
self.field_element.iter().all(|&x| x == 0x00)
}
} }
impl Mul for FieldElement { impl Mul for FieldElement {
@ -662,6 +680,68 @@ mod tests {
); );
} }
#[test]
fn test_field_add_zero() {
let json1 = json!([
"NeverGonnaGiveYouUpAAA==",
"NeverGonnaLetYouDownAA==",
"NeverGonnaRunAroundAAA==",
"AndDesertYouAAAAAAAAAA=="
]);
let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
let sum = element2 + element1;
assert_eq!(
sum.to_c_array(),
vec![
"NeverGonnaGiveYouUpAAA==",
"NeverGonnaLetYouDownAA==",
"NeverGonnaRunAroundAAA==",
"AndDesertYouAAAAAAAAAA=="
]
);
}
#[test]
fn test_field_add_zero_to_zero() {
let json1 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]);
let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
let sum = element2 + element1;
assert_eq!(sum.to_c_array(), vec!["AAAAAAAAAAAAAAAAAAAAAA=="]);
}
#[test]
fn test_field_add_short_to_long() {
let json1 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]);
let json2 = json!([
"NeverGonnaGiveYouUpAAA==",
"NeverGonnaLetYouDownAA==",
"NeverGonnaRunAroundAAA==",
"AndDesertYouAAAAAAAAAA=="
]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
let sum = element2 + element1;
assert_eq!(
sum.to_c_array(),
vec![
"NeverGonnaGiveYouUpAAA==",
"NeverGonnaLetYouDownAA==",
"NeverGonnaRunAroundAAA==",
"AndDesertYouAAAAAAAAAA=="
]
);
}
#[test] #[test]
fn test_field_mul_01() { fn test_field_mul_01() {
let json1 = json!([ let json1 = json!([
@ -690,7 +770,26 @@ mod tests {
} }
#[test] #[test]
fn test_field_pow_01() { fn test_poly_mul_with_zero() {
let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==",
"ACAAAAAAAAAAAAAAAAAAAA=="
]);
let json2 = json!(["AAAAAAAAAAAAAAAAAAAAAA=="]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
//eprintln!("{:?}", element1);
let result = element1 * element2;
assert_eq!(result.to_c_array(), vec!["AAAAAAAAAAAAAAAAAAAAAA=="]);
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
}
#[test]
fn test_poly_pow_01() {
let json1 = json!([ let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==", "JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==", "wAAAAAAAAAAAAAAAAAAAAA==",
@ -792,6 +891,6 @@ mod tests {
let result = element1.pow_mod(1000, modulus); let result = element1.pow_mod(1000, modulus);
eprintln!("Result is: {:02X?}", result); eprintln!("Result is: {:02X?}", result);
assert_eq!(result.to_c_array(), vec!["XrEhmKuat+Glt5zZWtMo6g=="]); assert_eq!(result.to_c_array(), vec!["oNXl5P8xq2WpUTP92u25zg=="]);
} }
} }