fix: Add handling of pow with 0

This commit is contained in:
alivecow 2024-11-15 10:26:38 +01:00
parent 5dc299372a
commit 2a9db307d9

View file

@ -54,7 +54,9 @@ impl Polynomial {
pub fn pow(&self, mut exponent: u128) -> Polynomial { pub fn pow(&self, mut exponent: u128) -> Polynomial {
if exponent == 0 { if exponent == 0 {
return Polynomial::new(vec![FieldElement::new(vec![0])]); return Polynomial::new(vec![FieldElement::new(
polynomial_2_block(vec![0], "gcm").unwrap(),
)]);
} }
let base = self.clone(); let base = self.clone();
@ -69,6 +71,12 @@ impl Polynomial {
} }
pub fn pow_mod(mut self, mut exponent: u128, modulus: Polynomial) -> Polynomial { pub fn pow_mod(mut self, mut exponent: u128, modulus: Polynomial) -> Polynomial {
if exponent == 0 {
return Polynomial::new(vec![FieldElement::new(
polynomial_2_block(vec![0], "gcm").unwrap(),
)]);
}
let mut result: Polynomial = Polynomial::new(vec![FieldElement::new( let mut result: Polynomial = Polynomial::new(vec![FieldElement::new(
polynomial_2_block(vec![0], "gcm").unwrap(), polynomial_2_block(vec![0], "gcm").unwrap(),
)]); )]);
@ -536,15 +544,13 @@ impl ByteArray {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use base64::prelude::*;
use serde_json::json; use serde_json::json;
use std::fs;
#[test] #[test]
fn test_byte_array_shift1() { fn test_byte_array_shift1() {
let mut byte_array: ByteArray = ByteArray(vec![0x00, 0x01]); let mut byte_array: ByteArray = ByteArray(vec![0x00, 0x01]);
let shifted_array: ByteArray = ByteArray(vec![0x00, 0x02]); let shifted_array: ByteArray = ByteArray(vec![0x00, 0x02]);
byte_array.left_shift("xex"); byte_array.left_shift("xex").unwrap();
assert_eq!(byte_array.0, shifted_array.0); assert_eq!(byte_array.0, shifted_array.0);
} }
@ -553,7 +559,7 @@ mod tests {
fn test_byte_array_shift2() { fn test_byte_array_shift2() {
let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]);
let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x01]); let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x01]);
byte_array.left_shift("xex"); byte_array.left_shift("xex").unwrap();
assert_eq!( assert_eq!(
byte_array.0, shifted_array.0, byte_array.0, shifted_array.0,
@ -566,7 +572,7 @@ mod tests {
fn test_byte_array_shift1_gcm() { fn test_byte_array_shift1_gcm() {
let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]);
let shifted_array: ByteArray = ByteArray(vec![0x7F, 0x80]); let shifted_array: ByteArray = ByteArray(vec![0x7F, 0x80]);
byte_array.left_shift("gcm"); byte_array.left_shift("gcm").unwrap();
assert_eq!( assert_eq!(
byte_array.0, shifted_array.0, byte_array.0, shifted_array.0,
@ -579,7 +585,7 @@ mod tests {
fn test_byte_array_shift1_right_gcm() { fn test_byte_array_shift1_right_gcm() {
let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]); let mut byte_array: ByteArray = ByteArray(vec![0xFF, 0x00]);
let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x00]); let shifted_array: ByteArray = ByteArray(vec![0xFE, 0x00]);
byte_array.right_shift("gcm"); byte_array.right_shift("gcm").unwrap();
assert_eq!( assert_eq!(
byte_array.0, shifted_array.0, byte_array.0, shifted_array.0,
@ -592,7 +598,7 @@ mod tests {
fn test_byte_array_shift_right() { fn test_byte_array_shift_right() {
let mut byte_array: ByteArray = ByteArray(vec![0x02]); let mut byte_array: ByteArray = ByteArray(vec![0x02]);
let shifted_array: ByteArray = ByteArray(vec![0x01]); let shifted_array: ByteArray = ByteArray(vec![0x01]);
byte_array.right_shift("xex"); byte_array.right_shift("xex").unwrap();
assert_eq!( assert_eq!(
byte_array.0, shifted_array.0, byte_array.0, shifted_array.0,
@ -814,6 +820,21 @@ mod tests {
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
} }
#[test]
fn test_poly_pow_with_zero() {
let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==",
"ACAAAAAAAAAAAAAAAAAAAA=="
]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let result = element1.pow(0);
assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]);
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
}
#[test] #[test]
fn test_field_pow_mod_01() { fn test_field_pow_mod_01() {
let json1 = json!([ let json1 = json!([
@ -840,6 +861,36 @@ mod tests {
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
} }
#[test]
fn test_field_pow_mod_with_zero() {
let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==",
"ACAAAAAAAAAAAAAAAAAAAA=="
]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let result = element1.pow(0);
assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]);
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
}
#[test]
fn test_field_pow_mod_10mill() {
let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==",
"ACAAAAAAAAAAAAAAAAAAAA=="
]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let result = element1.pow(10000000);
assert_eq!(result.to_c_array(), vec!["gAAAAAAAAAAAAAAAAAAAAA=="]);
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
}
#[test] #[test]
fn test_poly_div_01() { fn test_poly_div_01() {
let element1 = let element1 =