From 2a9db307d9f79c480bef7885bacf38f18659bd54 Mon Sep 17 00:00:00 2001 From: alivecow Date: Fri, 15 Nov 2024 10:26:38 +0100 Subject: [PATCH] fix: Add handling of pow with 0 --- src/utils/field.rs | 67 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/src/utils/field.rs b/src/utils/field.rs index 758ea75..5f47fce 100644 --- a/src/utils/field.rs +++ b/src/utils/field.rs @@ -54,7 +54,9 @@ impl Polynomial { pub fn pow(&self, mut exponent: u128) -> Polynomial { if exponent == 0 { - return Polynomial::new(vec![FieldElement::new(vec![0])]); + return Polynomial::new(vec![FieldElement::new( + polynomial_2_block(vec![0], "gcm").unwrap(), + )]); } let base = self.clone(); @@ -69,6 +71,12 @@ impl Polynomial { } pub fn pow_mod(mut self, mut exponent: u128, modulus: Polynomial) -> Polynomial { + if exponent == 0 { + return Polynomial::new(vec![FieldElement::new( + polynomial_2_block(vec![0], "gcm").unwrap(), + )]); + } + let mut result: Polynomial = Polynomial::new(vec![FieldElement::new( polynomial_2_block(vec![0], "gcm").unwrap(), )]); @@ -536,15 +544,13 @@ impl ByteArray { #[cfg(test)] mod tests { use super::*; - use base64::prelude::*; use serde_json::json; - use std::fs; #[test] fn test_byte_array_shift1() { let mut byte_array: ByteArray = ByteArray(vec![0x00, 0x01]); let shifted_array: ByteArray = ByteArray(vec![0x00, 0x02]); - byte_array.left_shift("xex"); + byte_array.left_shift("xex").unwrap(); assert_eq!(byte_array.0, shifted_array.0); } @@ -553,7 +559,7 @@ mod tests { fn test_byte_array_shift2() { let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x01]); - byte_array.left_shift("xex"); + byte_array.left_shift("xex").unwrap(); assert_eq!( byte_array.0, shifted_array.0, @@ -566,7 +572,7 @@ mod tests { fn test_byte_array_shift1_gcm() { let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let shifted_array: ByteArray = ByteArray(vec![0x7F, 0x80]); - byte_array.left_shift("gcm"); + byte_array.left_shift("gcm").unwrap(); assert_eq!( byte_array.0, shifted_array.0, @@ -579,7 +585,7 @@ mod tests { fn test_byte_array_shift1_right_gcm() { let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x00]); - byte_array.right_shift("gcm"); + byte_array.right_shift("gcm").unwrap(); assert_eq!( byte_array.0, shifted_array.0, @@ -592,7 +598,7 @@ mod tests { fn test_byte_array_shift_right() { let mut byte_array: ByteArray = ByteArray(vec![0x02]); let shifted_array: ByteArray = ByteArray(vec![0x01]); - byte_array.right_shift("xex"); + byte_array.right_shift("xex").unwrap(); assert_eq!( byte_array.0, shifted_array.0, @@ -814,6 +820,21 @@ mod tests { //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); } + #[test] + fn test_poly_pow_with_zero() { + let json1 = json!([ + "JAAAAAAAAAAAAAAAAAAAAA==", + "wAAAAAAAAAAAAAAAAAAAAA==", + "ACAAAAAAAAAAAAAAAAAAAA==" + ]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + + let result = element1.pow(0); + + assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]); + //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); + } + #[test] fn test_field_pow_mod_01() { let json1 = json!([ @@ -840,6 +861,36 @@ mod tests { //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); } + #[test] + fn test_field_pow_mod_with_zero() { + let json1 = json!([ + "JAAAAAAAAAAAAAAAAAAAAA==", + "wAAAAAAAAAAAAAAAAAAAAA==", + "ACAAAAAAAAAAAAAAAAAAAA==" + ]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + + let result = element1.pow(0); + + assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]); + //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); + } + + #[test] + fn test_field_pow_mod_10mill() { + let json1 = json!([ + "JAAAAAAAAAAAAAAAAAAAAA==", + "wAAAAAAAAAAAAAAAAAAAAA==", + "ACAAAAAAAAAAAAAAAAAAAA==" + ]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + + let result = element1.pow(10000000); + + assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]); + //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); + } + #[test] fn test_poly_div_01() { let element1 =